mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add script to auto annotate SD models and variants (#751)
* Add script to auto annotate SD models and variants * Add model config files * Add script to auto annotate SD models and variants * Add model config files * Move config files to shark_tank
This commit is contained in:
108
shark/examples/shark_inference/stable_diffusion/sd_annotation.py
Normal file
108
shark/examples/shark_inference/stable_diffusion/sd_annotation.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
from shark.model_annotation import model_annotation, create_context
|
||||
from shark.iree_utils._common import run_cmd, iree_target_map
|
||||
from shark.shark_downloader import (
|
||||
download_model,
|
||||
download_public_file,
|
||||
WORKDIR,
|
||||
)
|
||||
from shark.parser import shark_args
|
||||
from stable_args import args
|
||||
from opt_params import get_params
|
||||
from utils import set_init_device_flags
|
||||
|
||||
|
||||
# Downloads the model (Unet or VAE fp16) from shark_tank
|
||||
set_init_device_flags()
|
||||
shark_args.local_tank_cache = args.local_tank_cache
|
||||
bucket_key = f"{args.variant}/untuned"
|
||||
use_winograd = False
|
||||
if args.annotation_model == "unet":
|
||||
if args.version == "v2_1base":
|
||||
use_winograd = True
|
||||
model_key = f"{args.variant}/{args.version}/unet/{args.precision}/length_{args.max_length}/untuned"
|
||||
elif args.annotation_model == "vae":
|
||||
use_winograd = True
|
||||
is_base = "/base" if args.use_base_vae else ""
|
||||
model_key = f"{args.variant}/{args.version}/vae/{args.precision}/length_77/untuned{is_base}"
|
||||
|
||||
bucket, model_name, iree_flags = get_params(
|
||||
bucket_key, model_key, args.annotation_model, "untuned", args.precision
|
||||
)
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=bucket,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
# Downloads the tuned config files from shark_tank
|
||||
config_bucket = "gs://shark_tank/sd_tuned/configs/"
|
||||
if use_winograd:
|
||||
config_name = f"{args.annotation_model}_winograd.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
winograd_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
download_public_file(full_gs_url, winograd_config_dir, True)
|
||||
|
||||
if args.annotation_model == "unet":
|
||||
if args.variant in ["anythingv3", "analogdiffusion"]:
|
||||
args.max_length = 77
|
||||
config_name = f"{args.annotation_model}_{args.version}_{args.precision}_len{args.max_length}.json"
|
||||
full_gs_url = config_bucket + config_name
|
||||
lowering_config_dir = f"{WORKDIR}configs/" + config_name
|
||||
download_public_file(full_gs_url, lowering_config_dir, True)
|
||||
|
||||
# Annotate the model with Winograd attribute on selected conv ops
|
||||
if use_winograd:
|
||||
with create_context() as ctx:
|
||||
winograd_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=mlir_model,
|
||||
config_path=winograd_config_dir,
|
||||
search_op="conv",
|
||||
winograd=use_winograd,
|
||||
)
|
||||
with open(
|
||||
f"{args.annotation_output}/{model_name}_tuned_torch.mlir", "w"
|
||||
) as f:
|
||||
f.write(str(winograd_model))
|
||||
|
||||
# For Unet annotate the model with tuned lowering configs
|
||||
if args.annotation_model == "unet":
|
||||
if use_winograd:
|
||||
input_mlir = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
|
||||
else:
|
||||
input_mlir = f"{WORKDIR}/{model_name}_torch/{model_name}_torch.mlir"
|
||||
dump_after = "iree-flow-pad-linalg-ops"
|
||||
|
||||
# Dump IR after padding/img2col/winograd passes
|
||||
run_cmd(
|
||||
f"iree-compile {input_mlir} "
|
||||
"--iree-input-type=tm_tensor "
|
||||
f"--iree-hal-target-backends={iree_target_map(args.device)} "
|
||||
f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} "
|
||||
"--iree-stream-resource-index-bits=64 "
|
||||
"--iree-vm-target-index-bits=64 "
|
||||
"--iree-flow-enable-padding-linalg-ops "
|
||||
"--iree-flow-linalg-ops-padding-size=32 "
|
||||
"--iree-flow-enable-conv-img2col-transform "
|
||||
f"--mlir-print-ir-after={dump_after} "
|
||||
"--compile-to=flow "
|
||||
f"2>{args.annotation_output}/dump_after_winograd.mlir "
|
||||
)
|
||||
|
||||
# Annotate the model with lowering configs in the config file
|
||||
with create_context() as ctx:
|
||||
tuned_model = model_annotation(
|
||||
ctx,
|
||||
input_contents=f"{args.annotation_output}/dump_after_winograd.mlir",
|
||||
config_path=lowering_config_dir,
|
||||
search_op="all",
|
||||
)
|
||||
|
||||
# Remove the intermediate mlir and save the final annotated model
|
||||
os.remove(f"{args.annotation_output}/dump_after_winograd.mlir")
|
||||
output_path = f"{args.annotation_output}/{model_name}_tuned_torch.mlir"
|
||||
with open(output_path, "w") as f:
|
||||
f.write(str(tuned_model))
|
||||
print(f"Saved the annotated mlir in {output_path}.")
|
||||
@@ -1,4 +1,10 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def path_expand(s):
|
||||
return Path(s).expanduser().resolve()
|
||||
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@@ -223,4 +229,22 @@ p.add_argument(
|
||||
help="flag for removing the pregress bar animation during image generation",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### SD model auto-annotation flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_output",
|
||||
type=path_expand,
|
||||
default="./",
|
||||
help="Directory to save the annotated mlir file",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--annotation_model",
|
||||
type=str,
|
||||
default="unet",
|
||||
help="Options are unet and vae.",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
@@ -40,17 +40,23 @@ def model_annotation(
|
||||
input_contents: str,
|
||||
config_path: str,
|
||||
search_op: str,
|
||||
winograd: bool = False,
|
||||
):
|
||||
if os.path.isfile(input_contents):
|
||||
with open(input_contents, "rb") as f:
|
||||
input_contents = f.read()
|
||||
module = ir.Module.parse(input_contents)
|
||||
|
||||
configs = load_model_configs(config_path)
|
||||
if winograd:
|
||||
with open(config_path, "r") as f:
|
||||
data = json.load(f)
|
||||
configs = data["c,f"]
|
||||
else:
|
||||
configs = load_model_configs(config_path)
|
||||
|
||||
# The Python API does not expose a general walk() function, so we just
|
||||
# do it ourselves.
|
||||
walk_children(module.operation, configs, search_op)
|
||||
walk_children(module.operation, configs, search_op, winograd)
|
||||
|
||||
if not module.operation.verify():
|
||||
raise RuntimeError("Modified program does not verify!")
|
||||
@@ -92,7 +98,9 @@ def load_model_configs(config_path: str):
|
||||
return config
|
||||
|
||||
|
||||
def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
|
||||
def walk_children(
|
||||
op: ir.Operation, configs: List[Dict], search_op: str, winograd: bool
|
||||
):
|
||||
if search_op == "matmul":
|
||||
op_names = ["linalg.matmul", "mhlo.dot"]
|
||||
elif search_op == "bmm":
|
||||
@@ -121,6 +129,11 @@ def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
|
||||
# 'operation' and 'name' attributes.
|
||||
if isinstance(child_op, ir.OpView):
|
||||
child_op = child_op.operation
|
||||
if winograd and child_op.name in [
|
||||
"linalg.conv_2d_nchw_fchw",
|
||||
"linalg.conv_2d_nhwc_hwcf",
|
||||
]:
|
||||
add_winograd_attribute(child_op, configs)
|
||||
if child_op.name in op_names:
|
||||
if child_op.name == "linalg.generic":
|
||||
# This is for generic op that has contractionOpInterface
|
||||
@@ -151,7 +164,7 @@ def walk_children(op: ir.Operation, configs: List[Dict], search_op: str):
|
||||
)
|
||||
print(f"Updated op {child_op}", file=sys.stderr)
|
||||
|
||||
walk_children(child_op, configs, search_op)
|
||||
walk_children(child_op, configs, search_op, winograd)
|
||||
|
||||
|
||||
def get_op_shape(op: ir.Operation, search_op: str):
|
||||
@@ -294,10 +307,6 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
|
||||
pipeline_depth = config["pipeline_depth"]
|
||||
if "split_k" in config.keys():
|
||||
split_k = config["split_k"]
|
||||
if "devices" in config.keys():
|
||||
devices = config["devices"]
|
||||
if "shard_sizes" in config.keys():
|
||||
shard_sizes = config["shard_sizes"]
|
||||
elif "SPIRV" in config["pipeline"]:
|
||||
pipeline = config["pipeline"]
|
||||
tile_sizes = [
|
||||
@@ -355,6 +364,39 @@ def add_attributes(op: ir.Operation, config: List[Dict]):
|
||||
add_attribute_by_name(op, "iree_flow_split_k", split_k)
|
||||
|
||||
|
||||
def add_winograd_attribute(op: ir.Operation, config: List):
|
||||
op_result = str(op.results[0]).split("ins(")[1]
|
||||
dilation = int(
|
||||
str(op.attributes["dilations"]).split("dense<")[1].split(">")[0]
|
||||
)
|
||||
stride = int(
|
||||
str(op.attributes["strides"]).split("dense<")[1].split(">")[0]
|
||||
)
|
||||
|
||||
if op.name == "linalg.conv_2d_nchw_fchw":
|
||||
f = int(op_result.split("tensor<")[2].split("x")[0])
|
||||
c = int(op_result.split("tensor<")[2].split("x")[1])
|
||||
kh = int(op_result.split("tensor<")[2].split("x")[2])
|
||||
kw = int(op_result.split("tensor<")[2].split("x")[3])
|
||||
else:
|
||||
kh = int(op_result.split("tensor<")[2].split("x")[0])
|
||||
kw = int(op_result.split("tensor<")[2].split("x")[1])
|
||||
c = int(op_result.split("tensor<")[2].split("x")[2])
|
||||
f = int(op_result.split("tensor<")[2].split("x")[3])
|
||||
|
||||
if (
|
||||
dilation == 1
|
||||
and stride == 1
|
||||
and kh == 3
|
||||
and kw == 3
|
||||
and [c, f] in config
|
||||
):
|
||||
op.attributes["iree_winograd_conv"] = ir.IntegerAttr.get(
|
||||
ir.IntegerType.get_signless(64), 1
|
||||
)
|
||||
print("Apply Winograd on selected conv op: ", op)
|
||||
|
||||
|
||||
def add_attribute_by_name(op: ir.Operation, name: str, val: int):
|
||||
attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), val)
|
||||
op.attributes[name] = attr
|
||||
|
||||
Reference in New Issue
Block a user