Add support for automatic target triple selection for SD

This commit is contained in:
PhaneeshB
2022-12-21 16:15:45 +05:30
committed by Phaneesh Barwaria
parent b133a035a4
commit 2befe771b3
5 changed files with 180 additions and 85 deletions

View File

@@ -17,7 +17,7 @@ from tqdm.auto import tqdm
import numpy as np
from random import randint
from stable_args import args
from utils import get_shark_model, set_iree_runtime_flags
from utils import set_iree_runtime_flags
from opt_params import get_unet, get_vae, get_clip
from schedulers import (
SharkEulerDiscreteScheduler,

View File

@@ -5,22 +5,35 @@ from model_wrappers import (
get_clip_mlir,
)
from stable_args import args
from utils import get_shark_model
from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag
from utils import get_shark_model, map_device_to_name_path
from shark.iree_utils.vulkan_utils import get_vulkan_target_triple
BATCH_SIZE = len(args.prompts)
if BATCH_SIZE != 1:
sys.exit("Only batch size 1 is supported.")
# use tuned models only in the case of rdna3 cards.
if not args.iree_vulkan_target_triple:
vulkan_triple_flags = get_vulkan_triple_flag()
if vulkan_triple_flags and "rdna3" not in vulkan_triple_flags:
# global settings for device, iree-vulkan-target-triple and use_tuned flags
if "vulkan" not in args.device:
if args.use_tuned:
print("Tuned models not currently supported for device")
args.use_tuned = False
elif "rdna3" not in args.iree_vulkan_target_triple:
args.use_tuned = False
if args.use_tuned:
print("Using tuned models for rdna3 card")
else:
name, args.device = map_device_to_name_path(args.device)
triple = get_vulkan_target_triple(name)
print(f"Found device {name}. Using target triple {triple}")
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
if args.iree_vulkan_target_triple == "" and triple is not None:
args.iree_vulkan_target_triple = triple
# use tuned models only in the case of rdna3 cards.
if not args.iree_vulkan_target_triple:
if triple is not None and "rdna3" not in triple:
args.use_tuned = False
elif "rdna3" not in args.iree_vulkan_target_triple:
args.use_tuned = False
if args.use_tuned:
print("Using tuned models for rdna3 card")
def get_unet():

View File

@@ -4,7 +4,10 @@ import torch
from shark.shark_inference import SharkInference
from stable_args import args
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import set_iree_vulkan_runtime_flags
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
def _compile_module(shark_module, model_name, extra_args=[]):
@@ -86,3 +89,79 @@ def set_iree_runtime_flags():
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
return
def get_all_devices(driver_name):
"""
Inputs: driver_name
Returns a list of all the available devices for a given driver sorted by
the iree path names of the device as in --list_devices option in iree.
Set `full_dict` flag to True to get a dict
with `path`, `name` and `device_id` for all devices
"""
from iree.runtime import get_driver
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return device_list_src
def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired combination of name/path.
"""
from shark.iree_utils._common import iree_device_map
driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()
def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping

View File

@@ -23,21 +23,27 @@ import re
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
if "://" in device:
device = device.split("://")[0]
if device == "cpu":
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
print(
f"Specific device selection only supported for vulkan now."
f"Proceeding with {device} as device."
)
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
return get_iree_cpu_args()
if device == "cuda":
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device in ["metal", "vulkan"]:
if device_uri[0] in ["metal", "vulkan"]:
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(extra_args=extra_args)
if device == "rocm":
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args()

View File

@@ -26,9 +26,10 @@ def get_vulkan_device_name():
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print(
f"Found {len(vulkaninfo_list)} device names. choosing first one: {vulkaninfo_list[0]}"
)
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing first one: {vulkaninfo_list[0]}")
return vulkaninfo_list[0]
@@ -44,81 +45,77 @@ def get_os_name():
return "linux"
def get_vulkan_triple_flag(extra_args=[]):
if "-iree-vulkan-target-triple=" in " ".join(extra_args):
print(f"Using target triple from command line args")
return None
def get_vulkan_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
Args:
device_name (str): name of the hardware device to be used with vulkan
Returns:
str or None: target triple or None if no match found for given name
"""
system_os = get_os_name()
vulkan_device = get_vulkan_device_name()
# Apple Targets
if all(x in vulkan_device for x in ("Apple", "M1")):
print(f"Found {vulkan_device} Device. Using m1-moltenvk-macos")
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
elif all(x in vulkan_device for x in ("Apple", "M2")):
print("Found Apple M2 Device. Using m1-moltenvk-macos")
return "-iree-vulkan-target-triple=m1-moltenvk-macos"
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"
# Nvidia Targets
elif all(x in vulkan_device for x in ("RTX", "2080")):
print(
f"Found {vulkan_device} Device. Using turing-rtx2080-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx2080-{system_os}"
elif all(x in vulkan_device for x in ("A100", "SXM4")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3080-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3080-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "3090")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "4090")):
print(
f"Found {vulkan_device} Device. Using ampere-rtx3090-{system_os}"
)
return f"-iree-vulkan-target-triple=ampere-rtx3090-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "4000")):
print(
f"Found {vulkan_device} Device. Using turing-rtx4000-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx4000-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "5000")):
print(
f"Found {vulkan_device} Device. Using turing-rtx5000-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx5000-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "6000")):
print(
f"Found {vulkan_device} Device. Using turing-rtx6000-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx6000-{system_os}"
elif all(x in vulkan_device for x in ("RTX", "8000")):
print(
f"Found {vulkan_device} Device. Using turing-rtx8000-{system_os}"
)
return f"-iree-vulkan-target-triple=turing-rtx8000-{system_os}"
elif all(x in device_name for x in ("RTX", "2080")):
triple = f"turing-rtx2080-{system_os}"
elif all(x in device_name for x in ("A100", "SXM4")):
triple = f"ampere-rtx3080-{system_os}"
elif all(x in device_name for x in ("RTX", "3090")):
triple = f"ampere-rtx3090-{system_os}"
elif all(x in device_name for x in ("RTX", "4090")):
triple = f"ampere-rtx3090-{system_os}"
elif all(x in device_name for x in ("RTX", "4000")):
triple = f"turing-rtx4000-{system_os}"
elif all(x in device_name for x in ("RTX", "5000")):
triple = f"turing-rtx5000-{system_os}"
elif all(x in device_name for x in ("RTX", "6000")):
triple = f"turing-rtx6000-{system_os}"
elif all(x in device_name for x in ("RTX", "8000")):
triple = f"turing-rtx8000-{system_os}"
# Amd Targets
elif all(x in vulkan_device for x in ("AMD", "7900")):
print(f"Found {vulkan_device} Device. Using rdna3-7900-{system_os}")
return f"-iree-vulkan-target-triple=rdna3-7900-{system_os}"
elif any(x in vulkan_device for x in ("AMD", "Radeon")):
print(f"Found AMD device. Using rdna2-unknown-{system_os}")
return f"-iree-vulkan-target-triple=rdna2-unknown-{system_os}"
elif all(x in device_name for x in ("AMD", "7900")):
triple = f"rdna3-7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
else:
triple = None
return triple
def get_vulkan_triple_flag(device_name=None, extra_args=[]):
for flag in extra_args:
if "-iree-vulkan-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None
vulkan_device = (
device_name if device_name is not None else get_vulkan_device_name()
)
triple = get_vulkan_target_triple(vulkan_device)
if triple is not None:
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
f"Found vulkan device {vulkan_device}. Using target triple {triple}"
)
print(f"Target : {vulkan_device}")
return None
return f"-iree-vulkan-target-triple={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {vulkan_device}")
return None
def get_iree_vulkan_args(extra_args=[]):
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
vulkan_flag = []
vulkan_triple_flag = get_vulkan_triple_flag(extra_args)
vulkan_triple_flag = get_vulkan_triple_flag(extra_args=extra_args)
if vulkan_triple_flag is not None:
vulkan_flag.append(vulkan_triple_flag)
return vulkan_flag