mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add support for automatic target triple selection for SD
This commit is contained in:
committed by
Phaneesh Barwaria
parent
b133a035a4
commit
2befe771b3
@@ -17,7 +17,7 @@ from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from stable_args import args
|
||||
from utils import get_shark_model, set_iree_runtime_flags
|
||||
from utils import set_iree_runtime_flags
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
from schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
|
||||
@@ -5,22 +5,35 @@ from model_wrappers import (
|
||||
get_clip_mlir,
|
||||
)
|
||||
from stable_args import args
|
||||
from utils import get_shark_model
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
|
||||
from utils import get_shark_model, map_device_to_name_path
|
||||
from shark.iree_utils.vulkan_utils import get_vulkan_target_triple
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
sys.exit("Only batch size 1 is supported.")
|
||||
|
||||
# use tuned models only in the case of rdna3 cards.
|
||||
if not args.iree_vulkan_target_triple:
|
||||
vulkan_triple_flags = get_vulkan_triple_flag()
|
||||
if vulkan_triple_flags and "rdna3" not in vulkan_triple_flags:
|
||||
# global settings for device, iree-vulkan-target-triple and use_tuned flags
|
||||
if "vulkan" not in args.device:
|
||||
if args.use_tuned:
|
||||
print("Tuned models not currently supported for device")
|
||||
args.use_tuned = False
|
||||
elif "rdna3" not in args.iree_vulkan_target_triple:
|
||||
args.use_tuned = False
|
||||
if args.use_tuned:
|
||||
print("Using tuned models for rdna3 card")
|
||||
else:
|
||||
name, args.device = map_device_to_name_path(args.device)
|
||||
triple = get_vulkan_target_triple(name)
|
||||
print(f"Found device {name}. Using target triple {triple}")
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
if args.iree_vulkan_target_triple == "" and triple is not None:
|
||||
args.iree_vulkan_target_triple = triple
|
||||
|
||||
# use tuned models only in the case of rdna3 cards.
|
||||
if not args.iree_vulkan_target_triple:
|
||||
if triple is not None and "rdna3" not in triple:
|
||||
args.use_tuned = False
|
||||
elif "rdna3" not in args.iree_vulkan_target_triple:
|
||||
args.use_tuned = False
|
||||
|
||||
if args.use_tuned:
|
||||
print("Using tuned models for rdna3 card")
|
||||
|
||||
|
||||
def get_unet():
|
||||
|
||||
@@ -4,7 +4,10 @@ import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from stable_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
@@ -86,3 +89,79 @@ def set_iree_runtime_flags():
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
Inputs: driver_name
|
||||
Returns a list of all the available devices for a given driver sorted by
|
||||
the iree path names of the device as in --list_devices option in iree.
|
||||
Set `full_dict` flag to True to get a dict
|
||||
with `path`, `name` and `device_id` for all devices
|
||||
"""
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return device_list_src
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
"""This method ensures consistent device ordering when choosing
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
device_map = dict()
|
||||
|
||||
def get_output_value(dev_dict):
|
||||
if key_combination == 1:
|
||||
return f"{driver}://{dev_dict['path']}"
|
||||
if key_combination == 2:
|
||||
return dev_dict["name"]
|
||||
if key_combination == 3:
|
||||
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
|
||||
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
for i, device in enumerate(device_list):
|
||||
# mapping with index
|
||||
device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# mapping with full path
|
||||
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
return device_map
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
try:
|
||||
device_mapping = device_map[device]
|
||||
except KeyError:
|
||||
raise ValueError(f"Device '{device}' is not a valid device.")
|
||||
return device_mapping
|
||||
|
||||
@@ -23,21 +23,27 @@ import re
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
if device == "cpu":
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan"]:
|
||||
print(
|
||||
f"Specific device selection only supported for vulkan now."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
|
||||
if device_uri[0] == "cpu":
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
return get_iree_cpu_args()
|
||||
if device == "cuda":
|
||||
if device_uri[0] == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
return get_iree_gpu_args()
|
||||
if device in ["metal", "vulkan"]:
|
||||
if device_uri[0] in ["metal", "vulkan"]:
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(extra_args=extra_args)
|
||||
if device == "rocm":
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args()
|
||||
|
||||
@@ -26,9 +26,10 @@ def get_vulkan_device_name():
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print(
|
||||
f"Found {len(vulkaninfo_list)} device names. choosing first one: {vulkaninfo_list[0]}"
|
||||
)
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing first one: {vulkaninfo_list[0]}")
|
||||
return vulkaninfo_list[0]
|
||||
|
||||
|
||||
@@ -44,81 +45,77 @@ def get_os_name():
|
||||
return "linux"
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(extra_args=[]):
|
||||
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
|
||||
print(f"Using target triple from command line args")
|
||||
return None
|
||||
def get_vulkan_target_triple(device_name):
|
||||
"""This method provides a target triple str for specified vulkan device.
|
||||
|
||||
Args:
|
||||
device_name (str): name of the hardware device to be used with vulkan
|
||||
|
||||
Returns:
|
||||
str or None: target triple or None if no match found for given name
|
||||
"""
|
||||
system_os = get_os_name()
|
||||
vulkan_device = get_vulkan_device_name()
|
||||
# Apple Targets
|
||||
if all(x in vulkan_device for x in ("Apple", "M1")):
|
||||
print(f"Found {vulkan_device} Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
elif all(x in vulkan_device for x in ("Apple", "M2")):
|
||||
print("Found Apple M2 Device. Using m1-moltenvk-macos")
|
||||
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
|
||||
if all(x in device_name for x in ("Apple", "M1")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
elif all(x in device_name for x in ("Apple", "M2")):
|
||||
triple = "m1-moltenvk-macos"
|
||||
|
||||
# Nvidia Targets
|
||||
elif all(x in vulkan_device for x in ("RTX", "2080")):
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx2080-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx2080-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("A100", "SXM4")):
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3080-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3080-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "3090")):
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "4090")):
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "4000")):
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx4000-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx4000-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "5000")):
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx5000-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx5000-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "6000")):
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx6000-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx6000-{system_os}"
|
||||
elif all(x in vulkan_device for x in ("RTX", "8000")):
|
||||
print(
|
||||
f"Found {vulkan_device} Device. Using turing-rtx8000-{system_os}"
|
||||
)
|
||||
return f"-iree-vulkan-target-triple=turing-rtx8000-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "2080")):
|
||||
triple = f"turing-rtx2080-{system_os}"
|
||||
elif all(x in device_name for x in ("A100", "SXM4")):
|
||||
triple = f"ampere-rtx3080-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "3090")):
|
||||
triple = f"ampere-rtx3090-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "4090")):
|
||||
triple = f"ampere-rtx3090-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "4000")):
|
||||
triple = f"turing-rtx4000-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "5000")):
|
||||
triple = f"turing-rtx5000-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "6000")):
|
||||
triple = f"turing-rtx6000-{system_os}"
|
||||
elif all(x in device_name for x in ("RTX", "8000")):
|
||||
triple = f"turing-rtx8000-{system_os}"
|
||||
|
||||
# Amd Targets
|
||||
elif all(x in vulkan_device for x in ("AMD", "7900")):
|
||||
print(f"Found {vulkan_device} Device. Using rdna3-7900-{system_os}")
|
||||
return f"-iree-vulkan-target-triple=rdna3-7900-{system_os}"
|
||||
elif any(x in vulkan_device for x in ("AMD", "Radeon")):
|
||||
print(f"Found AMD device. Using rdna2-unknown-{system_os}")
|
||||
return f"-iree-vulkan-target-triple=rdna2-unknown-{system_os}"
|
||||
elif all(x in device_name for x in ("AMD", "7900")):
|
||||
triple = f"rdna3-7900-{system_os}"
|
||||
elif any(x in device_name for x in ("AMD", "Radeon")):
|
||||
triple = f"rdna2-unknown-{system_os}"
|
||||
else:
|
||||
triple = None
|
||||
return triple
|
||||
|
||||
|
||||
def get_vulkan_triple_flag(device_name=None, extra_args=[]):
|
||||
for flag in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in flag:
|
||||
print(f"Using target triple {flag.split('=')[1]}")
|
||||
return None
|
||||
|
||||
vulkan_device = (
|
||||
device_name if device_name is not None else get_vulkan_device_name()
|
||||
)
|
||||
triple = get_vulkan_target_triple(vulkan_device)
|
||||
if triple is not None:
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
or pull up an issue."""
|
||||
f"Found vulkan device {vulkan_device}. Using target triple {triple}"
|
||||
)
|
||||
print(f"Target : {vulkan_device}")
|
||||
return None
|
||||
return f"-iree-vulkan-target-triple={triple}"
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
or pull up an issue."""
|
||||
)
|
||||
print(f"Target : {vulkan_device}")
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_vulkan_args(extra_args=[]):
|
||||
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
vulkan_flag = []
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
|
||||
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
|
||||
if vulkan_triple_flag is not None:
|
||||
vulkan_flag.append(vulkan_triple_flag)
|
||||
return vulkan_flag
|
||||
|
||||
Reference in New Issue
Block a user