mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD][WEB] Avoid passing args to utils APIs
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
This commit is contained in:
@@ -18,7 +18,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--negative-prompts",
|
||||
nargs="+",
|
||||
default=["trees, green"],
|
||||
default=[""],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ p.add_argument(
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="cpu", help="device to run the model."
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
|
||||
@@ -98,7 +98,7 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
|
||||
guidance = gr.Slider(
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=7.5,
|
||||
@@ -156,7 +156,7 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
guidance,
|
||||
guidance_scale,
|
||||
seed,
|
||||
scheduler_key,
|
||||
],
|
||||
@@ -168,7 +168,7 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
guidance,
|
||||
guidance_scale,
|
||||
seed,
|
||||
scheduler_key,
|
||||
],
|
||||
|
||||
@@ -40,12 +40,8 @@ schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
# set use_tuned
|
||||
if "rdna3" not in get_vulkan_triple_flag():
|
||||
args.use_tuned = False
|
||||
|
||||
# set iree-runtime flags
|
||||
set_iree_runtime_flags(args)
|
||||
set_iree_runtime_flags()
|
||||
|
||||
cache_obj = dict()
|
||||
# cache vae, unet and clip.
|
||||
@@ -53,7 +49,7 @@ cache_obj = dict()
|
||||
cache_obj["vae"],
|
||||
cache_obj["unet"],
|
||||
cache_obj["clip"],
|
||||
) = (get_vae(args), get_unet(args), get_clip(args))
|
||||
) = (get_vae(), get_unet(), get_clip())
|
||||
|
||||
# cache tokenizer
|
||||
cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
|
||||
@@ -11,22 +11,19 @@ import numpy as np
|
||||
import time
|
||||
|
||||
|
||||
def set_ui_params(
|
||||
prompt, negative_prompt, steps, guidance, seed, scheduler_key
|
||||
):
|
||||
def set_ui_params(prompt, negative_prompt, steps, guidance_scale, seed):
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.steps = steps
|
||||
args.guidance = guidance
|
||||
args.guidance_scale = guidance_scale
|
||||
args.seed = seed
|
||||
args.scheduler = scheduler_key
|
||||
|
||||
|
||||
def stable_diff_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
scheduler_key: str,
|
||||
):
|
||||
@@ -37,14 +34,20 @@ def stable_diff_inf(
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
|
||||
set_ui_params(
|
||||
prompt, negative_prompt, steps, guidance, seed, scheduler_key
|
||||
)
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
set_ui_params(prompt, negative_prompt, steps, guidance_scale, seed)
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
generator = torch.manual_seed(
|
||||
args.seed
|
||||
) # Seed generator to create the inital latent noise
|
||||
guidance_scale = torch.tensor(args.guidance).to(torch.float32)
|
||||
|
||||
# set height and width.
|
||||
height = 512 # default height of Stable Diffusion
|
||||
width = 512 # default width of Stable Diffusion
|
||||
if args.version == "v2.1":
|
||||
height = 768
|
||||
width = 768
|
||||
|
||||
# Initialize vae and unet models.
|
||||
vae, unet, clip, tokenizer = (
|
||||
cache_obj["vae"],
|
||||
@@ -52,7 +55,7 @@ def stable_diff_inf(
|
||||
cache_obj["clip"],
|
||||
cache_obj["tokenizer"],
|
||||
)
|
||||
scheduler = schedulers[args.scheduler]
|
||||
scheduler = schedulers[scheduler_key]
|
||||
|
||||
start = time.time()
|
||||
text_input = tokenizer(
|
||||
@@ -84,7 +87,7 @@ def stable_diff_inf(
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents = torch.randn(
|
||||
(1, 4, args.height // 8, args.width // 8),
|
||||
(1, 4, height // 8, width // 8),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
@@ -109,7 +112,7 @@ def stable_diff_inf(
|
||||
latents_numpy,
|
||||
timestep,
|
||||
text_embeddings_numpy,
|
||||
guidance_scale,
|
||||
args.guidance_scale,
|
||||
)
|
||||
)
|
||||
noise_pred = torch.from_numpy(noise_pred)
|
||||
@@ -136,7 +139,7 @@ def stable_diff_inf(
|
||||
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}, version={args.version}"
|
||||
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, scheduler={scheduler_key}, seed={args.seed}, size={height}x{width}, version={args.version}"
|
||||
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
|
||||
print(f"\nAverage step time: {avg_ms}ms/it")
|
||||
text_output += "\nTotal image generation time: {0:.2f}sec".format(
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from models.stable_diffusion.utils import compile_through_fx
|
||||
from models.stable_diffusion.stable_args import args
|
||||
import torch
|
||||
|
||||
model_config = {
|
||||
"v2": "stabilityai/stable-diffusion-2",
|
||||
"v2.1": "stabilityai/stable-diffusion-2-1",
|
||||
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
"v1.4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
model_input = {
|
||||
"v2": {
|
||||
"v2.1": {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 96, 96),),
|
||||
"unet": (
|
||||
@@ -42,13 +43,16 @@ model_input = {
|
||||
},
|
||||
}
|
||||
|
||||
# revision param for from_pretrained defaults to "main" => fp32
|
||||
model_revision = "fp16" if args.precision == "fp16" else "main"
|
||||
|
||||
def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
"openai/clip-vit-large-patch14"
|
||||
)
|
||||
if args.version == "v2":
|
||||
if args.version != "v1.4":
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_config[args.version], subfolder="text_encoder"
|
||||
)
|
||||
@@ -63,7 +67,6 @@ def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
|
||||
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
args,
|
||||
clip_model,
|
||||
model_input[args.version]["clip"],
|
||||
model_name=model_name,
|
||||
@@ -72,10 +75,7 @@ def get_clip_mlir(args, model_name="clip_text", extra_args=[]):
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_vae_mlir(args, model_name="vae", extra_args=[]):
|
||||
# revision param for from_pretrained defaults to "main" => fp32
|
||||
model_revision = "fp16" if args.precision == "fp16" else "main"
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -102,7 +102,6 @@ def get_vae_mlir(args, model_name="vae", extra_args=[]):
|
||||
inputs = model_input[args.version]["vae"]
|
||||
|
||||
shark_vae = compile_through_fx(
|
||||
args,
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
@@ -111,7 +110,7 @@ def get_vae_mlir(args, model_name="vae", extra_args=[]):
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
|
||||
def get_vae_encode_mlir(model_name="vae_encode", extra_args=[]):
|
||||
class VaeEncodeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -131,7 +130,6 @@ def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
|
||||
[inputs.half().cuda() for inputs in model_input[args.version]["vae"]]
|
||||
)
|
||||
shark_vae = compile_through_fx(
|
||||
args,
|
||||
vae,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
@@ -140,9 +138,7 @@ def get_vae_encode_mlir(args, model_name="vae_encode", extra_args=[]):
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(args, model_name="unet", extra_args=[]):
|
||||
model_revision = "fp16" if args.precision == "fp16" else "main"
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -178,7 +174,6 @@ def get_unet_mlir(args, model_name="unet", extra_args=[]):
|
||||
else:
|
||||
inputs = model_input[args.version]["unet"]
|
||||
shark_unet = compile_through_fx(
|
||||
args,
|
||||
unet,
|
||||
inputs,
|
||||
model_name=model_name,
|
||||
|
||||
@@ -5,15 +5,23 @@ from models.stable_diffusion.model_wrappers import (
|
||||
get_unet_mlir,
|
||||
get_clip_mlir,
|
||||
)
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from models.stable_diffusion.utils import get_shark_model
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
sys.exit("Only batch size 1 is supported.")
|
||||
|
||||
def get_unet(args):
|
||||
|
||||
def get_unet():
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
# Tuned model is present for `fp16` precision.
|
||||
if args.precision == "fp16":
|
||||
if args.use_tuned:
|
||||
@@ -22,20 +30,22 @@ def get_unet(args):
|
||||
model_name = "unet_1dec_fp16_tuned"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "unet2base_8dec_fp16_tuned"
|
||||
return get_shark_model(args, bucket, model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
else:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "unet_8dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "unet2base_8dec_fp16"
|
||||
if args.version == "v2.1":
|
||||
model_name = "unet2_14dec_fp16"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_unet_mlir(args, model_name, iree_flags)
|
||||
return get_shark_model(args, bucket, model_name, iree_flags)
|
||||
return get_unet_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
# Tuned model is not present for `fp32` case.
|
||||
if args.precision == "fp32":
|
||||
@@ -47,29 +57,49 @@ def get_unet(args):
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_unet_mlir(args, model_name, iree_flags)
|
||||
return get_shark_model(args, bucket, model_name, iree_flags)
|
||||
return get_unet_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
if args.precision == "int8":
|
||||
bucket = "gs://shark_tank/prashant_nod"
|
||||
model_name = "unet_int8"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
]
|
||||
sys.exit("int8 model is currently in maintenance.")
|
||||
# # TODO: Pass iree_flags to the exported model.
|
||||
# if args.import_mlir:
|
||||
# sys.exit(
|
||||
# "--import_mlir is not supported for the int8 model, try --no-import_mlir flag."
|
||||
# )
|
||||
# return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae(args):
|
||||
def get_vae():
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
if args.precision == "fp16":
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
if args.precision in ["fp16", "int8"]:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_8dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "vae2base_8dec_fp16"
|
||||
if args.version == "v2.1":
|
||||
model_name = "vae2_14dec_fp16"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(args, model_name, iree_flags)
|
||||
return get_shark_model(args, bucket, model_name, iree_flags)
|
||||
return get_vae_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
if args.precision == "fp32":
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
@@ -80,23 +110,25 @@ def get_vae(args):
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(args, model_name, iree_flags)
|
||||
return get_shark_model(args, bucket, model_name, iree_flags)
|
||||
return get_vae_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_vae_encode(args):
|
||||
def get_vae_encode():
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
if args.precision == "fp16":
|
||||
if args.precision in ["fp16", "int8"]:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_encode_1dec_fp16"
|
||||
if args.version == "v2":
|
||||
model_name = "vae2_encode_29nov_fp16"
|
||||
iree_flags += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=32",
|
||||
"--iree-flow-enable-conv-img2col-transform",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_vae_encode_mlir(model_name, iree_flags)
|
||||
@@ -115,20 +147,25 @@ def get_vae_encode(args):
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_clip(args):
|
||||
def get_clip():
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "clip_8dec_fp32"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "clip2base_8dec_fp32"
|
||||
if args.version == "v2.1":
|
||||
model_name = "clip2_14dec_fp32"
|
||||
iree_flags += [
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
]
|
||||
if args.import_mlir:
|
||||
return get_clip_mlir(args, model_name, iree_flags)
|
||||
return get_shark_model(args, bucket, model_name, iree_flags)
|
||||
return get_clip_mlir(model_name, iree_flags)
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
@@ -4,6 +4,10 @@ p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
@@ -14,14 +18,10 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--negative-prompts",
|
||||
nargs="+",
|
||||
default=["trees, green"],
|
||||
default=[""],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
@@ -29,13 +29,6 @@ p.add_argument(
|
||||
help="the no. of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="v2.1base",
|
||||
help="Specify version of stable diffusion model",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
@@ -43,20 +36,6 @@ p.add_argument(
|
||||
help="the seed to use.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
help="the height to use.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
help="the width to use.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
@@ -65,10 +44,29 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--scheduler",
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=77,
|
||||
help="max length of the tokenizer output.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Model Config and Usage Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="EulerDiscrete",
|
||||
help="can be [PNDM, LMSDiscrete, DDIM, DPMSolverMultistep, EulerDiscrete]",
|
||||
default="v2.1base",
|
||||
help="Specify version of stable diffusion model",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -79,23 +77,30 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=77,
|
||||
help="max length of the tokenizer output.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--cache",
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree-vulkan-target-triple",
|
||||
type=str,
|
||||
@@ -104,16 +109,10 @@ p.add_argument(
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_tuned",
|
||||
default=True,
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Download and use the tuned version of the model if available",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -122,4 +121,46 @@ p.add_argument(
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for disabling vulkan validation layers when benchmarking",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Misc. Debug and Optimization flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--local_tank_cache",
|
||||
default="",
|
||||
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dump_isa",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="When enabled call amdllpc to get ISA dumps. use with dispatch benchmarks.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks",
|
||||
default=None,
|
||||
help='dispatches to return benchamrk data on. use "All" for all, and None for none.',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--dispatch_benchmarks_dir",
|
||||
default="temp_dispatch_benchmarks",
|
||||
help='directory where you want to store dispatch data generated with "--dispatch_benchmarks"',
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_rgp",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for inserting debug frames between iterations for use with rgp.",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
@@ -2,41 +2,43 @@ import os
|
||||
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from models.stable_diffusion.stable_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
|
||||
|
||||
|
||||
def set_iree_runtime_flags(args):
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
]
|
||||
if "vulkan" in args.device:
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _compile_module(args, shark_module, model_name, extra_args=[]):
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
if args.cache:
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if os.path.isfile(vmfb_path):
|
||||
print("Loading flatbuffer from {}".format(vmfb_path))
|
||||
shark_module.load_module(vmfb_path)
|
||||
return shark_module
|
||||
print("No vmfb found. Compiling and saving to {}".format(vmfb_path))
|
||||
path = shark_module.save_module(os.getcwd(), extended_name, extra_args)
|
||||
shark_module.load_module(path)
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(args, tank_url, model_name, extra_args=[]):
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -44,16 +46,18 @@ def get_shark_model(args, tank_url, model_name, extra_args=[]):
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name, tank_url=tank_url, frontend="torch"
|
||||
model_name,
|
||||
tank_url=tank_url,
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, func_name, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
return _compile_module(args, shark_module, model_name, extra_args)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into shark_module.
|
||||
def compile_through_fx(args, model, inputs, model_name, extra_args=[]):
|
||||
# Converts the torch-module into a shark_module.
|
||||
def compile_through_fx(model, inputs, model_name, extra_args=[]):
|
||||
|
||||
mlir_module, func_name = import_with_fx(model, inputs)
|
||||
|
||||
@@ -64,4 +68,21 @@ def compile_through_fx(args, model, inputs, model_name, extra_args=[]):
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
return _compile_module(args, shark_module, model_name, extra_args)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
|
||||
]
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if "vulkan" in args.device:
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user