mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Revert "Add target triple selection for multiple cards" (#655)
This reverts commit acb905f0cc.
This commit is contained in:
committed by
GitHub
parent
72648aa9f2
commit
831f206cd0
@@ -11,11 +11,7 @@ from diffusers import (
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from stable_args import args
|
||||
from utils import (
|
||||
get_shark_model,
|
||||
set_iree_runtime_flags,
|
||||
make_qualified_device_name,
|
||||
)
|
||||
from utils import get_shark_model, set_iree_runtime_flags
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
import time
|
||||
import sys
|
||||
@@ -68,7 +64,7 @@ if __name__ == "__main__":
|
||||
sys.exit("More than one prompt is not supported yet.")
|
||||
if batch_size != len(neg_prompt):
|
||||
sys.exit("prompts and negative prompts must be of same length")
|
||||
make_qualified_device_name()
|
||||
|
||||
set_iree_runtime_flags()
|
||||
unet = get_unet()
|
||||
vae = get_vae()
|
||||
|
||||
@@ -5,7 +5,6 @@ from shark.shark_inference import SharkInference
|
||||
from stable_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
|
||||
from shark.iree_utils._common import map_device_to_path
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
@@ -87,12 +86,3 @@ def set_iree_runtime_flags():
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def make_qualified_device_name():
|
||||
# modify device name to be fully qualified device name
|
||||
# of the format driver://path
|
||||
# supported for vulkan as of now
|
||||
|
||||
if "vulkan" in args.device:
|
||||
args.device = map_device_to_path(args.device)
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from iree.runtime import get_driver, get_device
|
||||
|
||||
|
||||
def run_cmd(cmd):
|
||||
@@ -38,57 +37,8 @@ def run_cmd(cmd):
|
||||
sys.exit("Exiting program due to error running:", cmd)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
Inputs: driver_name
|
||||
Returns a list of all the available devices for a given driver sorted by
|
||||
the iree path names of the device as in --list_devices option in iree.
|
||||
Set `full_dict` flag to True to get a dict
|
||||
with `path`, `name` and `device_id` for all devices
|
||||
"""
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return device_list_src
|
||||
|
||||
|
||||
def create_map_device_to_key(driver, key):
|
||||
# key can only be path, name, device id
|
||||
device_list = get_all_devices(driver)
|
||||
device_map = dict()
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = f"{device_list[0][key]}"
|
||||
for i, device in enumerate(device_list):
|
||||
# mapping with index
|
||||
device_map[f"{driver}://{i}"] = f"{device[key]}"
|
||||
# mapping with full path
|
||||
device_map[f"{driver}://{device['path']}"] = f"{device[key]}"
|
||||
|
||||
return device_map
|
||||
|
||||
|
||||
def map_device_to_path(device):
|
||||
driver = device.split("://")[0]
|
||||
device_map = create_map_device_to_key(driver, "path")
|
||||
try:
|
||||
device_path = device_map[device]
|
||||
except KeyError:
|
||||
raise Exception(f"Device {device} is not a valid device.")
|
||||
return f"{driver}://{device_path}"
|
||||
|
||||
|
||||
def map_device_to_name(device):
|
||||
driver = device.split("://")[0]
|
||||
device_map = create_map_device_to_key(driver, "name")
|
||||
try:
|
||||
device_name = device_map[device]
|
||||
except KeyError:
|
||||
raise Exception(f"Device {device} is not a valid device.")
|
||||
return device_name
|
||||
|
||||
|
||||
def iree_device_map(device):
|
||||
uri_parts = device.split("://", 1)
|
||||
uri_parts = device.split("://", 2)
|
||||
if len(uri_parts) == 1:
|
||||
return _IREE_DEVICE_MAP[uri_parts[0]]
|
||||
else:
|
||||
|
||||
@@ -23,27 +23,21 @@ import re
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan"]:
|
||||
print(
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
|
||||
if device_uri[0] == "cpu":
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
if device == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
return get_iree_cpu_args()
|
||||
if device_uri[0] == "cuda":
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
return get_iree_gpu_args()
|
||||
if device_uri[0] in ["metal", "vulkan"]:
|
||||
if device in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(device, extra_args)
|
||||
if device_uri[0] == "rocm":
|
||||
return get_iree_vulkan_args(extra_args=extra_args)
|
||||
if device == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args()
|
||||
|
||||
@@ -15,10 +15,23 @@
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
from os import linesep
|
||||
from shark.iree_utils._common import map_device_to_name
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
|
||||
|
||||
def get_vulkan_device_name():
|
||||
vulkaninfo_dump = run_cmd("vulkaninfo").split(linesep)
|
||||
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print(
|
||||
f"Found {len(vulkaninfo_list)} device names. choosing first one: {vulkaninfo_list[0]}"
|
||||
)
|
||||
return vulkaninfo_list[0]
|
||||
|
||||
|
||||
def get_os_name():
|
||||
if platform.startswith("linux"):
|
||||
return "linux"
|
||||
@@ -31,39 +44,62 @@ def get_os_name():
|
||||
return "linux"
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(device, extra_args=[]):
|
||||
def get_vulkan_triple_flag(extra_args=[]):
|
||||
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
|
||||
print(f"Using target triple from command line args")
|
||||
return None
|
||||
|
||||
system_os = get_os_name()
|
||||
vulkan_device = map_device_to_name(device)
|
||||
triple = None
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
# Apple Targets
|
||||
if all(x in vulkan_device for x in ("Apple", "M1")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
print(f"Found {vulkan_device} Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
elif all(x in vulkan_device for x in ("Apple", "M2")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
print("Found Apple M2 Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
# Nvidia Targets
|
||||
elif all(x in vulkan_device for x in ("A100", "SXM4")):
|
||||
triple = f"ampere-rtx3080-{system_os}"
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3080-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3080-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "3090")):
|
||||
triple = f"ampere-rtx3090-{system_os}"
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "4090")):
|
||||
triple = f"ampere-rtx3090-{system_os}"
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "4000")):
|
||||
triple = f"turing-rtx4000-{system_os}"
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx4000-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx4000-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "5000")):
|
||||
triple = f"turing-rtx5000-{system_os}"
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx5000-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx5000-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "6000")):
|
||||
triple = f"turing-rtx6000-{system_os}"
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx6000-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx6000-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "8000")):
|
||||
triple = f"turing-rtx8000-{system_os}"
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx8000-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx8000-{system_os}"
|
||||
# Amd Targets
|
||||
elif all(x in vulkan_device for x in ("AMD", "7900")):
|
||||
triple = f"rdna3-7900-{system_os}"
|
||||
print(f"Found {vulkan_device} Device. Using rdna3-7900-{system_os}")
|
||||
return f"-iree-vulkan-target-triple=rdna3-7900-{system_os}"
|
||||
elif any(x in vulkan_device for x in ("AMD", "Radeon")):
|
||||
triple = f"rdna2-unknown-{system_os}"
|
||||
print(f"Found AMD device. Using rdna2-unknown-{system_os}")
|
||||
return f"-iree-vulkan-target-triple=rdna2-unknown-{system_os}"
|
||||
else:
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
@@ -73,23 +109,17 @@ def get_vulkan_triple_flag(device, extra_args=[]):
|
||||
print(f"Target : {vulkan_device}")
|
||||
return None
|
||||
|
||||
print(f"Found {vulkan_device}. Using {triple}")
|
||||
return f"-iree-vulkan-target-triple={triple}"
|
||||
|
||||
|
||||
def get_iree_vulkan_args(device, extra_args=[]):
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
vulkan_flag = []
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(
|
||||
device=device, extra_args=extra_args
|
||||
)
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
|
||||
|
||||
def set_iree_vulkan_runtime_flags(flags):
|
||||
import iree.runtime as ireert
|
||||
|
||||
for flag in flags:
|
||||
ireert.flags.parse_flags(flag)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user