mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Create specified dir if needed during save_mlir and fix vulkan device fetching without URI/ID (#1989)
This commit is contained in:
2
.github/workflows/test-models.yml
vendored
2
.github/workflows/test-models.yml
vendored
@@ -147,7 +147,7 @@ jobs:
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --update_tank -k vulkan
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
|
||||
|
||||
- name: Validate Vulkan Models (Windows)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
|
||||
@@ -85,8 +85,9 @@ def clean_device_info(raw_device):
|
||||
device_id = int(device_id)
|
||||
|
||||
if device not in ["rocm", "vulkan"]:
|
||||
device_id = None
|
||||
|
||||
device_id = ""
|
||||
if device in ["rocm", "vulkan"] and device_id == None:
|
||||
device_id = 0
|
||||
return device, device_id
|
||||
|
||||
|
||||
|
||||
@@ -38,15 +38,24 @@ def get_all_vulkan_devices():
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing device: vulkan://{device_num}")
|
||||
return vulkaninfo_list[device_num]
|
||||
if isinstance(device_num, int):
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing device: vulkan://{device_num}")
|
||||
vulkan_device_name = vulkaninfo_list[device_num]
|
||||
else:
|
||||
from iree.runtime import get_driver
|
||||
|
||||
vulkan_device_driver = get_driver(device_num)
|
||||
vulkan_device_name = vulkan_device_driver.query_available_devices()[0]
|
||||
print(vulkan_device_name)
|
||||
return vulkan_device_name
|
||||
|
||||
|
||||
def get_os_name():
|
||||
|
||||
@@ -809,6 +809,8 @@ def save_mlir(
|
||||
dir = os.path.join(".", "shark_tmp")
|
||||
mlir_path = os.path.join(dir, model_name_mlir)
|
||||
print(f"saving {model_name_mlir} to {dir}")
|
||||
if not os.path.exists(dir):
|
||||
os.makedirs(dir)
|
||||
if frontend == "torch":
|
||||
with open(mlir_path, "wb") as mlir_file:
|
||||
mlir_file.write(mlir_module)
|
||||
|
||||
Reference in New Issue
Block a user