Add SD model annotation on the fly (#869)

* Add SD model annotation on the fly

* Move tuned_compile_through_fx to utils

* Fix SD compilation flags
This commit is contained in:
yzhang93
2023-01-26 11:46:36 -08:00
committed by GitHub
parent 9bbffa519e
commit fee73b0b63
8 changed files with 168 additions and 69 deletions

View File

@@ -125,6 +125,7 @@ if __name__ == "__main__":
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
)
clip, unet, vae = mlir_import()

View File

@@ -62,6 +62,7 @@ class SharkifyStableDiffusionModel:
height: int = 512,
batch_size: int = 1,
use_base_vae: bool = False,
use_tuned: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -82,6 +83,7 @@ class SharkifyStableDiffusionModel:
+ "_"
+ precision
)
self.use_tuned = use_tuned
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
@@ -133,6 +135,7 @@ class SharkifyStableDiffusionModel:
inputs,
is_f16=is_f16,
model_name=vae_name + self.model_name,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("vae", precision=self.precision),
)
return shark_vae
@@ -172,6 +175,7 @@ class SharkifyStableDiffusionModel:
model_name="unet" + self.model_name,
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
)
return shark_unet

View File

@@ -1,6 +1,6 @@
import os
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import run_cmd, iree_target_map
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
download_model,
download_public_file,
@@ -8,74 +8,95 @@ from shark.shark_downloader import (
)
from shark.parser import shark_args
from stable_args import args
from opt_params import get_params
from utils import set_init_device_flags
set_init_device_flags()
device = (
args.device if "://" not in args.device else args.device.split("://")[0]
)
# Downloads the model (Unet or VAE fp16) from shark_tank
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{args.variant}/untuned"
if args.annotation_model == "unet":
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
elif args.annotation_model == "vae":
is_base = "/base" if args.use_base_vae else ""
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}"
bucket, model_name, iree_flags = get_params(
bucket_key, model_key, args.annotation_model, "untuned", args.precision
)
mlir_model, func_name, inputs, golden_out = download_model(
model_name,
tank_url=bucket,
frontend="torch",
)
# Download the model (Unet or VAE fp16) from shark_tank
def load_model_from_tank():
    """Download the untuned Unet/VAE artifacts from shark_tank.

    Builds the tank lookup key from the CLI args (variant, version,
    precision, max_length), fetches the model, and returns the MLIR
    module together with its model name.
    """
    from opt_params import get_params, version, variant

    shark_args.local_tank_cache = args.local_tank_cache
    bucket_key = f"{variant}/untuned"
    if args.annotation_model == "unet":
        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned"
    elif args.annotation_model == "vae":
        # Base-VAE variants live under a "/base" suffix in the tank.
        is_base = "/base" if args.use_base_vae else ""
        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}"
    bucket, model_name, iree_flags = get_params(
        bucket_key, model_key, args.annotation_model, "untuned", args.precision
    )
    mlir_model, func_name, inputs, golden_out = download_model(
        model_name,
        tank_url=bucket,
        frontend="torch",
    )
    return mlir_model, model_name
# Download the tuned config files from shark_tank
def load_winograd_configs():
    """Fetch the per-device Winograd annotation config from the tank.

    Returns the local path of the downloaded JSON config file.
    """
    gs_bucket = "gs://shark_tank/sd_tuned/configs/"
    config_name = f"{args.annotation_model}_winograd_{device}.json"
    winograd_config_dir = f"{WORKDIR}configs/" + config_name
    print("Loading Winograd config file from ", winograd_config_dir)
    download_public_file(gs_bucket + config_name, winograd_config_dir, True)
    return winograd_config_dir
if args.annotation_model == "unet" or device == "cuda":
if args.variant in ["anythingv3", "analogdiffusion"]:
def load_lower_configs():
    """Fetch the tuned lowering config for the current model/device.

    Some variants (anythingv3, analogdiffusion) reuse the v1_4 configs;
    VAE configs are only published for max_length 77.
    Returns the local path of the downloaded JSON config file.
    """
    from opt_params import version, variant

    gs_bucket = "gs://shark_tank/sd_tuned/configs/"
    config_version = version
    if variant in ["anythingv3", "analogdiffusion"]:
        # NOTE: mutates the global args so later stages stay consistent.
        args.max_length = 77
        config_version = "v1_4"
    if args.annotation_model == "vae":
        args.max_length = 77
    config_name = f"{args.annotation_model}_{config_version}_{args.precision}_len{args.max_length}_{device}.json"
    lowering_config_dir = f"{WORKDIR}configs/" + config_name
    print("Loading lowering config file from ", lowering_config_dir)
    download_public_file(gs_bucket + config_name, lowering_config_dir, True)
    return lowering_config_dir
# Annotate the model with Winograd attribute on selected conv ops
if args.use_winograd:
def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
    """Annotate selected conv ops in the model with Winograd attributes.

    Args:
        input_mlir: MLIR module contents (or path) to annotate.
        winograd_config_dir: path to the downloaded Winograd JSON config.
        model_name: base name used to derive the output .mlir filename.

    Returns:
        (annotated module, path of the .mlir file it was written to).
    """
    # Avoid doubling the "_tuned" suffix when the name already carries it.
    if model_name.split("_")[-1] != "tuned":
        out_file_path = (
            f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
        )
    else:
        out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
    with create_context() as ctx:
        winograd_model = model_annotation(
            ctx,
            input_contents=input_mlir,
            config_path=winograd_config_dir,
            search_op="conv",
            winograd=True,
        )
    # Fix: dropped the redundant f.close() — the `with` block already
    # closes the file on exit.
    with open(out_file_path, "w") as f:
        f.write(str(winograd_model))
    return winograd_model, out_file_path
# For Unet annotate the model with tuned lowering configs
if args.annotation_model == "unet" or device == "cuda":
if args.use_winograd:
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
if use_winograd:
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
else:
input_mlir = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
dump_after = "iree-flow-pad-linalg-ops"
# Dump IR after padding/img2col/winograd passes
@@ -90,6 +111,8 @@ if args.annotation_model == "unet" or device == "cuda":
device_spec_args = (
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
)
print("Applying tuned configs on", model_name)
run_cmd(
f"iree-compile {input_mlir} "
"--iree-input-type=tm_tensor "
@@ -116,7 +139,48 @@ if args.annotation_model == "unet" or device == "cuda":
# Remove the intermediate mlir and save the final annotated model
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
with open(output_path, "w") as f:
if model_name.split("_")[-1] != "tuned":
out_file_path = (
f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return tuned_model, out_file_path
def sd_model_annotation(mlir_model, model_name):
    """Apply tuning annotations to a Stable Diffusion sub-model.

    Dispatch per device/model:
      * Unet on Vulkan  — Winograd pass, then tuned lowering configs.
      * VAE on Vulkan   — Winograd pass only.
      * everything else (e.g. CUDA) — tuned lowering configs only.

    Args:
        mlir_model: MLIR module contents (or a path on the CUDA route).
        model_name: base name used for config lookup and output files.

    Returns:
        (annotated module, path of the saved annotated .mlir file).
    """
    if args.annotation_model == "unet" and device == "vulkan":
        use_winograd = True
        winograd_config_dir = load_winograd_configs()
        winograd_model, model_path = annotate_with_winograd(
            mlir_model, winograd_config_dir, model_name
        )
        lowering_config_dir = load_lower_configs()
        tuned_model, output_path = annotate_with_lower_configs(
            model_path, lowering_config_dir, model_name, use_winograd
        )
    elif args.annotation_model == "vae" and device == "vulkan":
        # Fix: removed the unused `use_winograd = True` assignment —
        # this branch never forwards the flag anywhere.
        winograd_config_dir = load_winograd_configs()
        tuned_model, output_path = annotate_with_winograd(
            mlir_model, winograd_config_dir, model_name
        )
    else:
        use_winograd = False
        lowering_config_dir = load_lower_configs()
        tuned_model, output_path = annotate_with_lower_configs(
            mlir_model, lowering_config_dir, model_name, use_winograd
        )
    print(f"Saved the annotated mlir in {output_path}.")
    return tuned_model, output_path
if __name__ == "__main__":
    mlir_model, model_name = load_model_from_tank()
    if device == "cuda":
        # On CUDA the annotation runs on the locally dumped .mlir file
        # rather than the in-memory module fetched from the tank.
        mlir_model = f"{WORKDIR}{model_name}_torch/{model_name}_torch.mlir"
    sd_model_annotation(mlir_model, model_name)

View File

@@ -93,7 +93,7 @@ p.add_argument(
p.add_argument(
"--import_mlir",
default=True,
default=False,
action=argparse.BooleanOptionalAction,
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
)
@@ -299,11 +299,4 @@ p.add_argument(
help="Options are unet and vae.",
)
p.add_argument(
"--use_winograd",
default=False,
action=argparse.BooleanOptionalAction,
help="Apply Winograd on selected conv ops.",
)
args = p.parse_args()

View File

@@ -1,4 +1,5 @@
import os
import gc
import torch
from shark.shark_inference import SharkInference
from stable_args import args
@@ -9,6 +10,7 @@ from shark.iree_utils.vulkan_utils import (
)
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from resources import opt_flags
from sd_annotation import sd_model_annotation
import sys
@@ -70,12 +72,40 @@ def compile_through_fx(
model_name,
is_f16=False,
f16_input_mask=None,
use_tuned=False,
extra_args=[],
):
mlir_module, func_name = import_with_fx(
model, inputs, is_f16, f16_input_mask
model, inputs, is_f16, f16_input_mask, return_str=use_tuned
)
if use_tuned:
model_name = model_name + "_tuned"
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
if not os.path.exists(tuned_model_path):
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
if "cuda" in args.device:
output_path = (
f"{args.annotation_output}/{model_name}_orig.mlir"
)
with open(output_path, "w") as f:
f.write(mlir_module)
f.close()
mlir_module = output_path
tuned_model, tuned_model_path = sd_model_annotation(
mlir_module, model_name
)
del mlir_module, tuned_model
gc.collect()
with open(tuned_model_path, "rb") as f:
mlir_module = f.read()
f.close()
shark_module = SharkInference(
mlir_module,
device=args.device,
@@ -202,36 +232,30 @@ def set_init_device_flags():
elif args.hf_model_id == "prompthero/openjourney":
args.max_length = 64
# Use tuned models in the case of stablediffusion/fp16 and rdna3 cards.
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
if (
args.hf_model_id
in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
or args.precision != "fp16"
or "vulkan" not in args.device
or "rdna3" not in args.iree_vulkan_target_triple
or ("vulkan" not in args.device and "cuda" not in args.device)
):
args.use_tuned = False
elif (
"vulkan" in args.device
and "rdna3" not in args.iree_vulkan_target_triple
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
]:
args.use_tuned = False
# Use tuned model in the case of stablediffusion/fp16 and cuda device sm_80
if (
args.hf_model_id
in [
"stabilityai/stable-diffusion-2-1-base",
"Linaqruf/anything-v3.0",
"wavymulder/Analog-Diffusion",
]
and args.precision == "fp16"
and "cuda" in args.device
and get_cuda_sm_cc() in ["sm_80", "sm_89"]
):
args.use_tuned = True
if args.use_tuned:
print(f"Using {args.device} tuned models for stablediffusion/fp16.")
else:
@@ -287,6 +311,11 @@ def get_opt_flags(model, precision="fp16"):
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
"default_compilation_flags"
]
if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]:
device = (
args.device
@@ -303,7 +332,6 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
return iree_flags