mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Disable passing of sm_arch to iree-compile CL args by default. (#253)
* Disable passing of sm_arch to iree-compile CL args by default.
* Fix formatting.
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
from shark.parser import shark_args
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
def get_iree_gpu_args():
|
||||
@@ -23,7 +24,9 @@ def get_iree_gpu_args():
|
||||
ireert.flags.parse_flags("--cuda_allow_inline_execution")
|
||||
# TODO: Give the user_interface to pass the sm_arch.
|
||||
sm_arch = get_cuda_sm_cc()
|
||||
if sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]:
|
||||
if (
|
||||
sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]
|
||||
) and (shark_args.enable_tf32 == True):
|
||||
return [
|
||||
"--iree-hal-cuda-disable-loop-nounroll-wa",
|
||||
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
|
||||
|
||||
@@ -47,16 +47,10 @@ parser.add_argument(
|
||||
default="./shark_tmp",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_mlir",
|
||||
"--enable_tf32",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Saves input MLIR module to /tmp/ directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Saves iree .vmfb module to /tmp/ directory.",
|
||||
help="Enables TF32 precision calculations on supported GPUs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_config_path",
|
||||
|
||||
Reference in New Issue
Block a user