Create specified dir if needed during save_mlir and fix vulkan device fetching without URI/ID (#1989)

This commit is contained in:
Ean Garvey
2023-11-23 01:01:41 -06:00
committed by GitHub
parent ce38d49f05
commit da50a16242
4 changed files with 24 additions and 12 deletions

View File

@@ -147,7 +147,7 @@ jobs:
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'

View File

@@ -85,8 +85,9 @@ def clean_device_info(raw_device):
device_id = int(device_id)
if device not in ["rocm", "vulkan"]:
device_id = None
device_id = ""
if device in ["rocm", "vulkan"] and device_id == None:
device_id = 0
return device, device_id

View File

@@ -38,15 +38,24 @@ def get_all_vulkan_devices():
@functools.cache
def get_vulkan_device_name(device_num=0):
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: vulkan://{device_num}")
return vulkaninfo_list[device_num]
if isinstance(device_num, int):
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: vulkan://{device_num}")
vulkan_device_name = vulkaninfo_list[device_num]
else:
from iree.runtime import get_driver
vulkan_device_driver = get_driver(device_num)
vulkan_device_name = vulkan_device_driver.query_available_devices()[0]
print(vulkan_device_name)
return vulkan_device_name
def get_os_name():

View File

@@ -809,6 +809,8 @@ def save_mlir(
dir = os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if not os.path.exists(dir):
os.makedirs(dir)
if frontend == "torch":
with open(mlir_path, "wb") as mlir_file:
mlir_file.write(mlir_module)