add cli args for vulkan target triple

This commit is contained in:
PhaneeshB
2022-10-24 11:47:49 +05:30
committed by Prashant Kumar
parent 9956099516
commit fd578a48a9
4 changed files with 36 additions and 26 deletions

View File

@@ -19,10 +19,17 @@ VAE_FP16 = "vae_fp16"
VAE_FP32 = "vae_fp32"
UNET_FP16 = "unet_fp16"
UNET_FP32 = "unet_fp32"
IREE_EXTRA_ARGS = []
def get_models():
global IREE_EXTRA_ARGS
if args.precision == "fp16":
IREE_EXTRA_ARGS += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir == True:
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
model_name=UNET_FP16
@@ -31,22 +38,19 @@ def get_models():
return get_shark_model(
GCLOUD_BUCKET,
VAE_FP16,
[
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
],
IREE_EXTRA_ARGS,
), get_shark_model(
GCLOUD_BUCKET,
UNET_FP16,
[
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
],
IREE_EXTRA_ARGS,
)
elif args.precision == "fp32":
IREE_EXTRA_ARGS += [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
]
if args.import_mlir == True:
return get_vae32(model_name=VAE_FP32), get_unet32_wrapped(
model_name=UNET_FP32
@@ -55,25 +59,21 @@ def get_models():
return get_shark_model(
GCLOUD_BUCKET,
VAE_FP32,
[
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
],
IREE_EXTRA_ARGS,
), get_shark_model(
GCLOUD_BUCKET,
UNET_FP32,
[
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
],
IREE_EXTRA_ARGS,
)
if __name__ == "__main__":
dtype = torch.float32 if args.precision == "fp32" else torch.half
if len(args.iree_vulkan_target_triple) > 0:
IREE_EXTRA_ARGS.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
prompt = [args.prompt]

View File

@@ -64,5 +64,11 @@ p.add_argument(
help="saves the compiled flatbuffer to the local directory",
)
p.add_argument(
"--iree-vulkan-target-triple",
type=str,
default="",
help="Specify target triple for vulkan",
)
args = p.parse_args()

View File

@@ -18,7 +18,7 @@ import numpy as np
import os
# Get the iree-compile arguments given device.
def get_iree_device_args(device):
def get_iree_device_args(device, extra_args=[]):
if device == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
@@ -30,7 +30,7 @@ def get_iree_device_args(device):
if device in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args()
return get_iree_vulkan_args(extra_args=extra_args)
if device == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
@@ -68,7 +68,7 @@ def compile_module_to_flatbuffer(
# Setup Compile arguments wrt to frontends.
input_type = ""
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device)
args += get_iree_device_args(device, extra_args)
args += get_iree_common_args()
args += extra_args

View File

@@ -17,7 +17,11 @@
from shark.iree_utils._common import run_cmd
def get_vulkan_triple_flag():
def get_vulkan_triple_flag(extra_args=[]):
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
print(f"Using target triple from command line args")
return None
vulkan_device_cmd = "vulkaninfo | grep deviceName"
vulkan_device = run_cmd(vulkan_device_cmd).strip()
if all(x in vulkan_device for x in ("Apple", "M1")):
@@ -52,10 +56,10 @@ def get_vulkan_triple_flag():
return None
def get_iree_vulkan_args():
def get_iree_vulkan_args(extra_args=[]):
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
vulkan_flag = []
vulkan_triple_flag = get_vulkan_triple_flag()
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
if vulkan_triple_flag is not None:
vulkan_flag.append(vulkan_triple_flag)
return vulkan_flag