mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Fix conditionals.
This commit is contained in:
@@ -47,21 +47,21 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
+ stack_size_flag
|
||||
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
|
||||
)
|
||||
if device_uri[0] == "cuda":
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
return get_iree_gpu_args()
|
||||
if device_uri[0] == "vulkan":
|
||||
if device == "vulkan":
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "metal":
|
||||
if device == "metal":
|
||||
from shark.iree_utils.metal_utils import get_iree_metal_args
|
||||
|
||||
return get_iree_metal_args(extra_args=extra_args)
|
||||
if device_uri[0] == "rocm":
|
||||
if device == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
|
||||
|
||||
Reference in New Issue
Block a user