mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
add cli args for vulkan target triple
This commit is contained in:
committed by
Prashant Kumar
parent
9956099516
commit
fd578a48a9
@@ -19,10 +19,17 @@ VAE_FP16 = "vae_fp16"
|
||||
VAE_FP32 = "vae_fp32"
|
||||
UNET_FP16 = "unet_fp16"
|
||||
UNET_FP32 = "unet_fp32"
|
||||
IREE_EXTRA_ARGS = []
|
||||
|
||||
|
||||
def get_models():
|
||||
global IREE_EXTRA_ARGS
|
||||
if args.precision == "fp16":
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir == True:
|
||||
return get_vae16(model_name=VAE_FP16), get_unet16_wrapped(
|
||||
model_name=UNET_FP16
|
||||
@@ -31,22 +38,19 @@ def get_models():
|
||||
return get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
VAE_FP16,
|
||||
[
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
],
|
||||
IREE_EXTRA_ARGS,
|
||||
), get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
UNET_FP16,
|
||||
[
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
],
|
||||
IREE_EXTRA_ARGS,
|
||||
)
|
||||
|
||||
elif args.precision == "fp32":
|
||||
IREE_EXTRA_ARGS += [
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
]
|
||||
if args.import_mlir == True:
|
||||
return get_vae32(model_name=VAE_FP32), get_unet32_wrapped(
|
||||
model_name=UNET_FP32
|
||||
@@ -55,25 +59,21 @@ def get_models():
|
||||
return get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
VAE_FP32,
|
||||
[
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
],
|
||||
IREE_EXTRA_ARGS,
|
||||
), get_shark_model(
|
||||
GCLOUD_BUCKET,
|
||||
UNET_FP32,
|
||||
[
|
||||
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
|
||||
"--iree-flow-enable-padding-linalg-ops",
|
||||
"--iree-flow-linalg-ops-padding-size=16",
|
||||
],
|
||||
IREE_EXTRA_ARGS,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
IREE_EXTRA_ARGS.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
prompt = [args.prompt]
|
||||
|
||||
|
||||
@@ -64,5 +64,11 @@ p.add_argument(
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree-vulkan-target-triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
@@ -18,7 +18,7 @@ import numpy as np
|
||||
import os
|
||||
|
||||
# Get the iree-compile arguments for the given device.
|
||||
def get_iree_device_args(device):
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
if device == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
@@ -30,7 +30,7 @@ def get_iree_device_args(device):
|
||||
if device in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args()
|
||||
return get_iree_vulkan_args(extra_args=extra_args)
|
||||
if device == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
@@ -68,7 +68,7 @@ def compile_module_to_flatbuffer(
|
||||
# Set up compile arguments with respect to frontends.
|
||||
input_type = ""
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args()
|
||||
args += extra_args
|
||||
|
||||
|
||||
@@ -17,7 +17,11 @@
|
||||
from shark.iree_utils._common import run_cmd
|
||||
|
||||
|
||||
def get_vulkan_triple_flag():
|
||||
def get_vulkan_triple_flag(extra_args=[]):
|
||||
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
|
||||
print(f"Using target triple from command line args")
|
||||
return None
|
||||
|
||||
vulkan_device_cmd = "vulkaninfo | grep deviceName"
|
||||
vulkan_device = run_cmd(vulkan_device_cmd).strip()
|
||||
if all(x in vulkan_device for x in ("Apple", "M1")):
|
||||
@@ -52,10 +56,10 @@ def get_vulkan_triple_flag():
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_vulkan_args():
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
vulkan_flag = []
|
||||
vulkan_triple_flag = get_vulkan_triple_flag()
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
|
||||
Reference in New Issue
Block a user