mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 05:24:00 -05:00
231 lines
7.9 KiB
Python
231 lines
7.9 KiB
Python
import os
|
|
import torch
|
|
from amdshark.amdshark_inference import AMDSharkInference
|
|
from upscaler_args import args
|
|
from amdshark.amdshark_importer import import_with_fx
|
|
from amdshark.iree_utils.vulkan_utils import (
|
|
set_iree_vulkan_runtime_flags,
|
|
get_vulkan_target_triple,
|
|
get_iree_vulkan_runtime_flags,
|
|
)
|
|
|
|
|
|
def _compile_module(amdshark_module, model_name, extra_args=None):
    """Compile ``amdshark_module``, optionally via a ``.vmfb`` file in the CWD.

    Behaviour is driven by the global ``args``:

    * Neither ``args.load_vmfb`` nor ``args.save_vmfb`` set: call
      ``amdshark_module.compile(extra_args)``.
    * ``args.load_vmfb`` set, ``args.save_vmfb`` unset and
      ``<cwd>/<model_name>_<device>.vmfb`` exists: load that file.
    * Otherwise: call ``save_module`` with the current working directory and
      then load the path it returns.

    Args:
        amdshark_module: object exposing ``compile``, ``save_module`` and
            ``load_module``.
        model_name (str): base name used to build the vmfb file name.
        extra_args (list, optional): extra arguments forwarded unchanged to
            the module methods. Defaults to a fresh empty list.

    Returns:
        The same ``amdshark_module`` that was passed in.
    """
    # Use a fresh list per call instead of a shared mutable default, so that
    # nothing downstream can leak state between unrelated calls.
    extra_args = [] if extra_args is None else extra_args

    if not (args.load_vmfb or args.save_vmfb):
        amdshark_module.compile(extra_args)
        return amdshark_module

    # "driver://path" is turned into "driver-path" for the file name; a
    # device string without "://" is used as-is.
    device = args.device.replace("://", "-")
    extended_name = f"{model_name}_{device}"
    vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")

    # An explicit save request always wins over loading an existing file.
    if args.load_vmfb and not args.save_vmfb and os.path.isfile(vmfb_path):
        print(f"loading existing vmfb from: {vmfb_path}")
        amdshark_module.load_module(vmfb_path, extra_args=extra_args)
        return amdshark_module

    if args.save_vmfb:
        print(f"Saving to {vmfb_path}")
    else:
        print(f"No vmfb found. Compiling and saving to {vmfb_path}")

    path = amdshark_module.save_module(os.getcwd(), extended_name, extra_args)
    amdshark_module.load_module(path, extra_args=extra_args)
    return amdshark_module
|
|
|
|
|
|
# Downloads the model from amdshark_tank and returns the amdshark_module.
|
|
def get_amdshark_model(tank_url, model_name, extra_args=None):
    """Download ``model_name`` from ``tank_url`` and return it compiled.

    The model is fetched with ``download_model`` (torch frontend), wrapped in
    an ``AMDSharkInference`` for ``args.device`` with the "linalg" dialect,
    and handed to ``_compile_module``.

    Args:
        tank_url (str): location passed to ``download_model`` as ``tank_url``.
        model_name (str): name of the model to download; also used by
            ``_compile_module`` when naming a vmfb file.
        extra_args (list, optional): extra arguments forwarded to
            ``_compile_module``. Defaults to a fresh empty list.

    Returns:
        The compiled module returned by ``_compile_module``.
    """
    from amdshark.amdshark_downloader import download_model

    # NOTE(review): unused while the cache override below is disabled; the
    # import is kept in case loading the parser module matters — confirm.
    from amdshark.parser import amdshark_args  # noqa: F401

    # Set local amdshark_tank cache directory.
    # amdshark_args.local_tank_cache = args.local_tank_cache

    # Use a fresh list per call instead of a shared mutable default.
    extra_args = [] if extra_args is None else extra_args

    # Only the MLIR module is needed here; the other returned values are
    # intentionally ignored.
    mlir_model, _func_name, _inputs, _golden_out = download_model(
        model_name,
        tank_url=tank_url,
        frontend="torch",
    )
    amdshark_module = AMDSharkInference(
        mlir_model, device=args.device, mlir_dialect="linalg"
    )
    return _compile_module(amdshark_module, model_name, extra_args)
|
|
|
|
|
|
# Converts the torch-module into an amdshark_module.
|
|
def compile_through_fx(
    model,
    inputs,
    model_name,
    is_f16=False,
    f16_input_mask=None,
    extra_args=None,
):
    """Import a torch ``model`` through FX and return it compiled.

    ``import_with_fx`` produces the MLIR module, which is wrapped in an
    ``AMDSharkInference`` for ``args.device`` with the "linalg" dialect and
    then handed to ``_compile_module``.

    Args:
        model: the torch module to import.
        inputs: example inputs forwarded to ``import_with_fx``.
        model_name (str): name used by ``_compile_module`` when naming a
            vmfb file.
        is_f16 (bool, optional): forwarded to ``import_with_fx``.
        f16_input_mask (optional): forwarded to ``import_with_fx``.
        extra_args (list, optional): extra arguments forwarded to
            ``_compile_module``. Defaults to a fresh empty list.

    Returns:
        The compiled module returned by ``_compile_module``.
    """
    # Use a fresh list per call instead of a shared mutable default.
    extra_args = [] if extra_args is None else extra_args

    # The function name returned alongside the module is not needed here.
    mlir_module, _func_name = import_with_fx(
        model, inputs, is_f16, f16_input_mask
    )
    amdshark_module = AMDSharkInference(
        mlir_module,
        device=args.device,
        mlir_dialect="linalg",
    )

    return _compile_module(amdshark_module, model_name, extra_args)
|
|
|
|
|
|
def set_iree_runtime_flags():
    """Install the IREE Vulkan runtime flags.

    Starts from the flags returned by ``get_iree_vulkan_runtime_flags`` and,
    when ``args.enable_rgp`` is set, appends the RGP and Vulkan debug-utils
    switches before passing everything to ``set_iree_vulkan_runtime_flags``.
    """
    flags = get_iree_vulkan_runtime_flags()
    if args.enable_rgp:
        flags += ["--enable_rgp=true", "--vulkan_debug_utils=true"]
    set_iree_vulkan_runtime_flags(flags=flags)
|
|
|
|
|
|
def get_all_devices(driver_name):
    """List the devices an IREE driver reports, ordered by device path.

    Args:
        driver_name: name of the IREE driver to query.

    Returns:
        The list from ``query_available_devices()``, sorted in place by each
        entry's ``"path"`` value.
    """
    from iree.runtime import get_driver

    devices = get_driver(driver_name).query_available_devices()
    # Ordering by path keeps device indices stable between calls.
    devices.sort(key=lambda entry: entry["path"])
    return devices
|
|
|
|
|
|
def get_device_mapping(driver, key_combination=3):
    """Build a lookup from user-facing device strings to device data.

    This ensures consistent device ordering when choosing specific devices
    for execution. Three kinds of key are registered:

    * ``"<driver>"``: the first device in the sorted list
    * ``"<driver>://<index>"``: one per device, by position
    * ``"<driver>://<path>"``: one per device, by its path

    Args:
        driver (str): execution driver (vulkan, cuda, rocm, etc); it is
            translated with ``iree_device_map`` before use.
        key_combination (int, optional): choice of mapped value.
            1 : path, as ``"<driver>://<path>"``
            2 : name
            3 : (name, path)
            Any other value maps every key to ``None``.
            Defaults to 3.

    Returns:
        dict: possible device names a user can input, mapped to the chosen
        combination of name/path. Empty when the driver reports no devices.
    """
    from amdshark.iree_utils._common import iree_device_map

    driver = iree_device_map(driver)
    device_list = get_all_devices(driver)
    device_map = {}

    # Without any device there is no default entry to register. Returning an
    # empty map lets callers report an unknown device instead of failing on
    # device_list[0] with an IndexError.
    if not device_list:
        return device_map

    def get_output_value(dev_dict):
        if key_combination == 1:
            return f"{driver}://{dev_dict['path']}"
        if key_combination == 2:
            return dev_dict["name"]
        if key_combination == 3:
            return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
        return None

    # mapping driver name to default device (driver://0)
    device_map[f"{driver}"] = get_output_value(device_list[0])
    for i, device in enumerate(device_list):
        # Same value for both spellings of a device, so compute it once.
        value = get_output_value(device)
        # mapping with index
        device_map[f"{driver}://{i}"] = value
        # mapping with full path
        device_map[f"{driver}://{device['path']}"] = value
    return device_map
|
|
|
|
|
|
def map_device_to_name_path(device, key_combination=3):
    """Resolve a user-selected device string to its name and/or path.

    The driver is taken from the part of ``device`` before ``"://"`` and the
    lookup table comes from ``get_device_mapping``.

    Args:
        device (str): device string chosen by the user, either a bare driver
            name or ``"<driver>://<index-or-path>"``.
        key_combination (int, optional): choice of mapped value.
            1 : path
            2 : name
            3 : (name, path)
            Defaults to 3.

    Raises:
        ValueError: if ``device`` is not a key of the device mapping.

    Returns:
        str / tuple: the mapped string, or tuple of strings, for the device
        depending on ``key_combination``.
    """
    driver_name = device.partition("://")[0]
    known_devices = get_device_mapping(driver_name, key_combination)
    try:
        return known_devices[device]
    except KeyError:
        raise ValueError(f"Device '{device}' is not a valid device.")
|
|
|
|
|
|
def set_init_device_flags():
    """Normalise the global ``args`` for the selected device and variant.

    Steps, in order:

    1. For a vulkan device: set the runtime flags, resolve ``args.device`` to
       its full path, and fill ``args.iree_vulkan_target_triple`` when it is
       unset and a triple can be detected. For cuda/cpu devices the device
       string is collapsed to ``"cuda"`` / ``"cpu"``.
    2. Set ``args.max_length`` for the variants that need a specific value.
    3. Turn ``args.use_tuned`` off for every configuration other than the
       supported one, and report the outcome.
    """
    if "vulkan" in args.device:
        # Runtime flags have to be in place before the device is queried.
        set_iree_runtime_flags()

        # Resolve the triple once here to avoid repeated lookups later.
        device_name, args.device = map_device_to_name_path(args.device)
        if not args.iree_vulkan_target_triple:
            detected_triple = get_vulkan_target_triple(device_name)
            if detected_triple is not None:
                args.iree_vulkan_target_triple = detected_triple
        print(
            f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
        )
    elif "cuda" in args.device:
        args.device = "cuda"
    elif "cpu" in args.device:
        args.device = "cpu"

    # Variant-specific max_length.
    if args.variant in ("anythingv3", "analogdiffusion", "dreamlike"):
        args.max_length = 77
    elif args.variant == "openjourney":
        args.max_length = 64

    # Tuned models are only used for stablediffusion/fp16 on rdna3 cards.
    # NOTE(review): the rdna3 test is only reached for vulkan devices and
    # assumes the target triple is a string by then — confirm its default.
    tuned_unsupported = (
        args.variant in ("openjourney", "dreamlike")
        or args.precision != "fp16"
        or "vulkan" not in args.device
        or "rdna3" not in args.iree_vulkan_target_triple
        or (args.use_base_vae and args.variant != "stablediffusion")
    )
    if tuned_unsupported:
        args.use_tuned = False
        print("Tuned models are currently not supported for this setting.")

    if args.use_tuned:
        print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
|
|
|
|
|
|
# Utility to get list of devices available.
|
|
def get_available_devices():
    """Return display strings for the devices that can be selected.

    Vulkan devices are listed first, then cuda devices, each formatted as
    ``"<driver>://<index> => <device name>"``, followed by a final ``"cpu"``
    entry. A driver whose devices cannot be queried contributes nothing.

    Returns:
        list[str]: the available device entries.
    """

    def get_devices_by_name(driver_name):
        """List ``driver_name``'s devices, or nothing if it is unusable."""
        from amdshark.iree_utils._common import iree_device_map

        try:
            driver_name = iree_device_map(driver_name)
            device_list_dict = get_all_devices(driver_name)
        except Exception:
            # Best effort: an unusable driver is reported and skipped.
            # Catching Exception instead of using a bare except lets
            # KeyboardInterrupt and SystemExit propagate.
            print(f"{driver_name} devices are not available.")
            return []

        print(f"{driver_name} devices are available.")
        return [
            f"{driver_name}://{i} => {device['name']}"
            for i, device in enumerate(device_list_dict)
        ]

    # Vulkan runtime flags are set before any driver is queried.
    set_iree_runtime_flags()

    available_devices = []
    available_devices.extend(get_devices_by_name("vulkan"))
    available_devices.extend(get_devices_by_name("cuda"))
    available_devices.append("cpu")
    return available_devices
|