mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Don't convert device ID to int and fix .exe imports
This commit is contained in:
@@ -31,7 +31,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
device_num = int(device_uri[1])
|
||||
device_num = device_uri[1]
|
||||
else:
|
||||
device_num = 0
|
||||
|
||||
@@ -46,7 +46,9 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(device_num=device_num,extra_args=extra_args)
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_vulkan_args(device_num=0,extra_args=[]):
|
||||
def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_vulkan_flag = []
|
||||
@@ -156,7 +156,9 @@ def get_iree_vulkan_args(device_num=0,extra_args=[]):
|
||||
break
|
||||
|
||||
if vulkan_triple_flag is None:
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(device_num=device_num,extra_args=extra_args)
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
|
||||
|
||||
@@ -30,8 +30,8 @@ import os
|
||||
import sys
|
||||
from typing import Dict, List
|
||||
|
||||
import iree.compiler._mlir_libs
|
||||
from iree.compiler import ir
|
||||
from iree.compiler.transforms import ireec as ireec_trans
|
||||
|
||||
|
||||
def model_annotation(
|
||||
@@ -409,7 +409,6 @@ def shape_list_to_string(input):
|
||||
|
||||
def create_context() -> ir.Context:
|
||||
context = ir.Context()
|
||||
ireec_trans.register_all_dialects(context)
|
||||
context.allow_unregistered_dialects = True
|
||||
return context
|
||||
|
||||
|
||||
Reference in New Issue
Block a user