Update rocm arg handling in SD utils

This commit is contained in:
Ean Garvey
2023-08-16 13:23:37 -05:00
committed by GitHub
parent c9cdc8f3c7
commit 7d77d6cfb2

View File

@@ -25,7 +25,7 @@ from shark.iree_utils.vulkan_utils import (
get_iree_vulkan_runtime_flags,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
@@ -63,8 +63,6 @@ def _load_vmfb(shark_module, vmfb_path, model, precision):
def _compile_module(shark_module, model_name, extra_args=[]):
if args.iree_rocm_bc_dir is not None:
extra_args.append(f"-iree-rocm-bc-dir={args.iree_rocm_bc_dir}")
if args.load_vmfb or args.save_vmfb:
vmfb_path = get_vmfb_path_name(model_name)
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
@@ -503,9 +501,10 @@ def get_opt_flags(model, precision="fp16"):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if args.iree_rocm_bc_dir is not None:
iree_flags.append(f"-iree-rocm-bc-dir={args.iree_rocm_bc_dir}")
iree_flags.append("-iree-rocm-link-bc=True")
if "rocm" in args.device:
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
print(iree_flags)
if args.iree_constant_folding == False:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(