Don't convert device ID to int and fix .exe imports

This commit is contained in:
Ean Garvey
2023-05-12 15:18:51 -07:00
parent 38c7a9d2f5
commit 1d3b41670c
3 changed files with 9 additions and 6 deletions

View File

@@ -31,7 +31,7 @@ def get_iree_device_args(device, extra_args=[]):
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
device_num = int(device_uri[1])
device_num = device_uri[1]
else:
device_num = 0
@@ -46,7 +46,9 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(device_num=device_num,extra_args=extra_args)
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args

View File

@@ -144,7 +144,7 @@ def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
return None
def get_iree_vulkan_args(device_num=0,extra_args=[]):
def get_iree_vulkan_args(device_num=0, extra_args=[]):
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = []
@@ -156,7 +156,9 @@ def get_iree_vulkan_args(device_num=0,extra_args=[]):
break
if vulkan_triple_flag is None:
vulkan_triple_flag = get_vulkan_triple_flag(device_num=device_num,extra_args=extra_args)
vulkan_triple_flag = get_vulkan_triple_flag(
device_num=device_num, extra_args=extra_args
)
if vulkan_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)

View File

@@ -30,8 +30,8 @@ import os
import sys
from typing import Dict, List
import iree.compiler._mlir_libs
from iree.compiler import ir
from iree.compiler.transforms import ireec as ireec_trans
def model_annotation(
@@ -409,7 +409,6 @@ def shape_list_to_string(input):
def create_context() -> ir.Context:
context = ir.Context()
ireec_trans.register_all_dialects(context)
context.allow_unregistered_dialects = True
return context