Disable passing of sm_arch to iree-compile CL args by default. (#253)

* Disable passing of sm_arch to iree-compile CL args by default.

* Fix formatting.
This commit is contained in:
Ean Garvey
2022-08-10 03:19:24 -05:00
committed by GitHub
parent f7f24dc4d9
commit 23619068eb
2 changed files with 6 additions and 9 deletions

View File

@@ -16,6 +16,7 @@
import iree.runtime as ireert
import ctypes
from shark.parser import shark_args
# Get the default gpu args given the architecture.
def get_iree_gpu_args():
@@ -23,7 +24,9 @@ def get_iree_gpu_args():
ireert.flags.parse_flags("--cuda_allow_inline_execution")
# TODO: Give the user_interface to pass the sm_arch.
sm_arch = get_cuda_sm_cc()
if sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]:
if (
sm_arch in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86"]
) and (shark_args.enable_tf32 == True):
return [
"--iree-hal-cuda-disable-loop-nounroll-wa",
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",

View File

@@ -47,16 +47,10 @@ parser.add_argument(
default="./shark_tmp",
)
parser.add_argument(
"--save_mlir",
"--enable_tf32",
default=False,
action="store_true",
help="Saves input MLIR module to /tmp/ directory.",
)
parser.add_argument(
"--save_vmfb",
default=False,
action="store_true",
help="Saves iree .vmfb module to /tmp/ directory.",
help="Enables TF32 precision calculations on supported GPUs.",
)
parser.add_argument(
"--model_config_path",