mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Update rocm arg handling in SD utils
This commit is contained in:
@@ -25,7 +25,7 @@ from shark.iree_utils.vulkan_utils import (
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.resources import opt_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
@@ -63,8 +63,6 @@ def _load_vmfb(shark_module, vmfb_path, model, precision):
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.iree_rocm_bc_dir is not None:
|
||||
extra_args.append(f"-iree-rocm-bc-dir={args.iree_rocm_bc_dir}")
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
vmfb_path = get_vmfb_path_name(model_name)
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
@@ -503,9 +501,10 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
if args.iree_rocm_bc_dir is not None:
|
||||
iree_flags.append(f"-iree-rocm-bc-dir={args.iree_rocm_bc_dir}")
|
||||
iree_flags.append("-iree-rocm-link-bc=True")
|
||||
if "rocm" in args.device:
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
print(iree_flags)
|
||||
if args.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
|
||||
Reference in New Issue
Block a user