mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-14 16:28:01 -05:00
132 lines
3.8 KiB
Python
132 lines
3.8 KiB
Python
import os
|
|
from pathlib import Path
|
|
from shark_tuner.codegen_tuner import SharkCodegenTuner
|
|
from shark_tuner.iree_utils import (
|
|
dump_dispatches,
|
|
create_context,
|
|
export_module_to_mlir_file,
|
|
)
|
|
from shark_tuner.model_annotation import model_annotation
|
|
from apps.stable_diffusion.src.utils.stable_args import args
|
|
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
|
|
from apps.stable_diffusion.src.utils.sd_annotation import (
|
|
get_device_args,
|
|
load_winograd_configs,
|
|
)
|
|
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
|
|
|
|
|
def load_mlir_module():
|
|
if "upscaler" in args.hf_model_id:
|
|
is_upscaler = True
|
|
else:
|
|
is_upscaler = False
|
|
sd_model = SharkifyStableDiffusionModel(
|
|
args.hf_model_id,
|
|
args.ckpt_loc,
|
|
args.custom_vae,
|
|
args.precision,
|
|
max_len=args.max_length,
|
|
batch_size=args.batch_size,
|
|
height=args.height,
|
|
width=args.width,
|
|
use_base_vae=args.use_base_vae,
|
|
is_upscaler=is_upscaler,
|
|
use_tuned=False,
|
|
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
|
return_mlir=True,
|
|
)
|
|
|
|
if args.annotation_model == "unet":
|
|
mlir_module = sd_model.unet()
|
|
model_name = sd_model.model_name["unet"]
|
|
elif args.annotation_model == "vae":
|
|
mlir_module = sd_model.vae()
|
|
model_name = sd_model.model_name["vae"]
|
|
else:
|
|
raise ValueError(
|
|
f"{args.annotation_model} is not supported for tuning."
|
|
)
|
|
|
|
return mlir_module, model_name
|
|
|
|
|
|
def main():
|
|
args.use_tuned = False
|
|
set_init_device_flags()
|
|
mlir_module, model_name = load_mlir_module()
|
|
|
|
# Get device and device specific arguments
|
|
device, device_spec_args = get_device_args()
|
|
device_spec = ""
|
|
vulkan_target_triple = ""
|
|
if device_spec_args:
|
|
device_spec = device_spec_args[-1].split("=")[-1].strip()
|
|
if device == "vulkan":
|
|
vulkan_target_triple = device_spec
|
|
device_spec = device_spec.split("-")[0]
|
|
|
|
# Add winograd annotation for vulkan device
|
|
use_winograd = (
|
|
True
|
|
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
|
|
else False
|
|
)
|
|
winograd_config = (
|
|
load_winograd_configs()
|
|
if device == "vulkan" and args.annotation_model in ["unet", "vae"]
|
|
else ""
|
|
)
|
|
with create_context() as ctx:
|
|
input_module = model_annotation(
|
|
ctx,
|
|
input_contents=mlir_module,
|
|
config_path=winograd_config,
|
|
search_op="conv",
|
|
winograd=use_winograd,
|
|
)
|
|
|
|
# Dump model dispatches
|
|
generates_dir = Path.home() / "tmp"
|
|
if not os.path.exists(generates_dir):
|
|
os.makedirs(generates_dir)
|
|
dump_mlir = generates_dir / "temp.mlir"
|
|
dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"
|
|
export_module_to_mlir_file(input_module, dump_mlir)
|
|
dump_dispatches(
|
|
dump_mlir,
|
|
device,
|
|
dispatch_dir,
|
|
vulkan_target_triple,
|
|
use_winograd=use_winograd,
|
|
)
|
|
|
|
# Tune each dispatch
|
|
dtype = "f16" if args.precision == "fp16" else "f32"
|
|
config_filename = f"{model_name}_{device_spec}_configs.json"
|
|
|
|
for f_path in os.listdir(dispatch_dir):
|
|
if not f_path.endswith(".mlir"):
|
|
continue
|
|
|
|
model_dir = os.path.join(dispatch_dir, f_path)
|
|
|
|
tuner = SharkCodegenTuner(
|
|
model_dir,
|
|
device,
|
|
"random",
|
|
args.num_iters,
|
|
args.tuned_config_dir,
|
|
dtype,
|
|
args.search_op,
|
|
batch_size=1,
|
|
config_filename=config_filename,
|
|
use_dispatch=True,
|
|
vulkan_target_triple=vulkan_target_triple,
|
|
)
|
|
tuner.tune()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|