Purge unused code and patch out iree runtime handling from init

This commit is contained in:
Ean Garvey
2024-06-03 18:00:05 -05:00
parent 59600456be
commit dac7a29eef
27 changed files with 144 additions and 11471 deletions

View File

@@ -12,10 +12,7 @@ from tqdm.auto import tqdm
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)
from apps.shark_studio.api.controlnet import control_adapter_map
@@ -31,11 +28,8 @@ from apps.shark_studio.modules.img_processing import (
save_output_img,
)
from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
save_irpa,
)
from subprocess import check_output
EMPTY_SD_MAP = {
"clip": None,
"scheduler": None,
@@ -67,7 +61,6 @@ def load_script(source, module_name):
:param module_name: name of module to register in sys.modules
:return: loaded module
"""
spec = importlib.util.spec_from_file_location(module_name, source)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
@@ -118,10 +111,15 @@ class StableDiffusion:
self.dynamic_steps = False
self.model_map = custom_module.MODEL_MAP
elif self.is_sdxl:
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)
self.turbine_pipe = SharkSDXLPipeline
self.dynamic_steps = False
self.model_map = EMPTY_SDXL_MAP
else:
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
self.turbine_pipe = SharkSDPipeline
self.dynamic_steps = True
self.model_map = EMPTY_SD_MAP
@@ -207,6 +205,10 @@ class StableDiffusion:
self.compiled_pipeline = compiled_pipeline
if custom_weights:
from apps.shark_studio.modules.ckpt_processing import (
preprocessCKPT,
save_irpa,
)
custom_weights = os.path.join(
get_checkpoints_path("checkpoints"),
safe_name(self.base_model_id.split("/")[-1]),
@@ -534,11 +536,11 @@ if __name__ == "__main__":
global_obj._init()
sd_json = view_json_file(
get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
get_resource_path(os.path.join(cmd_opts.config_dir, cmd_opts.default_config))
)
sd_kwargs = json.loads(sd_json)
for arg in vars(cmd_opts):
if arg in sd_kwargs:
sd_kwargs[arg] = getattr(cmd_opts, arg)
# for arg in vars(cmd_opts):
# if arg in sd_kwargs:
# sd_kwargs[arg] = getattr(cmd_opts, arg)
for i in shark_sd_fn_dict_input(sd_kwargs):
print(i)

View File

@@ -11,15 +11,11 @@ from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info
# TODO: migrate these utils to studio
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
def get_available_devices():
return ["AMD Radeon 780M => rocm"]
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
@@ -49,7 +45,7 @@ def get_available_devices():
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
set_iree_runtime_flags()
#set_iree_runtime_flags()
available_devices = []
rocm_devices = get_devices_by_name("rocm")
@@ -96,55 +92,6 @@ def get_available_devices():
return available_devices
def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
set_iree_runtime_flags()
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_vulkan_target_triple:
triple = get_vulkan_target_triple(device_name)
if triple is not None:
cmd_opts.iree_vulkan_target_triple = triple
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_vulkan_target_triple}."
)
elif "cuda" in cmd_opts.device:
cmd_opts.device = "cuda"
elif "metal" in cmd_opts.device:
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
if not cmd_opts.iree_metal_target_platform:
from shark.iree_utils.metal_utils import get_metal_target_triple
triple = get_metal_target_triple(device_name)
if triple is not None:
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
print(
f"Found device {device_name}. Using target triple "
f"{cmd_opts.iree_metal_target_platform}."
)
elif "cpu" in cmd_opts.device:
cmd_opts.device = "cpu"
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if cmd_opts.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if cmd_opts.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
@@ -213,6 +160,7 @@ def get_all_devices(driver_name):
driver = get_driver(driver_name)
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
del driver
return device_list_src
@@ -281,115 +229,115 @@ def get_opt_flags(model, precision="fp16"):
return iree_flags
def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Args:
device (str): user
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Raises:
ValueError:
Returns:
str / tuple: returns the mapping str or tuple of mapping str for
the device depending on key_combination value
"""
driver = device.split("://")[0]
device_map = get_device_mapping(driver, key_combination)
try:
device_mapping = device_map[device]
except KeyError:
raise ValueError(f"Device '{device}' is not a valid device.")
return device_mapping
# def map_device_to_name_path(device, key_combination=3):
# """Gives the appropriate device data (supported name/path) for user
# selected execution device
# Args:
# device (str): user
# key_combination (int, optional): choice for mapping value for
# device name.
# 1 : path
# 2 : name
# 3 : (name, path)
# Defaults to 3.
# Raises:
# ValueError:
# Returns:
# str / tuple: returns the mapping str or tuple of mapping str for
# the device depending on key_combination value
# """
# driver = device.split("://")[0]
# device_map = get_device_mapping(driver, key_combination)
# try:
# device_mapping = device_map[device]
# except KeyError:
# raise ValueError(f"Device '{device}' is not a valid device.")
# return device_mapping
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
# def get_devices_by_name(driver_name):
# from shark.iree_utils._common import iree_device_map
device_list = []
try:
driver_name = iree_device_map(driver_name)
device_list_dict = get_all_devices(driver_name)
print(f"{driver_name} devices are available.")
except:
print(f"{driver_name} devices are not available.")
else:
cpu_name = get_cpu_info()["brand_raw"]
for i, device in enumerate(device_list_dict):
device_name = (
cpu_name if device["name"] == "default" else device["name"]
)
if "local" in driver_name:
device_list.append(
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
return device_list
# device_list = []
# try:
# driver_name = iree_device_map(driver_name)
# device_list_dict = get_all_devices(driver_name)
# print(f"{driver_name} devices are available.")
# except:
# print(f"{driver_name} devices are not available.")
# else:
# cpu_name = get_cpu_info()["brand_raw"]
# for i, device in enumerate(device_list_dict):
# device_name = (
# cpu_name if device["name"] == "default" else device["name"]
# )
# if "local" in driver_name:
# device_list.append(
# f"{device_name} => {driver_name.replace('local', 'cpu')}"
# )
# else:
# # for drivers with single devices
# # let the default device be selected without any indexing
# if len(device_list_dict) == 1:
# device_list.append(f"{device_name} => {driver_name}")
# else:
# device_list.append(f"{device_name} => {driver_name}://{i}")
# return device_list
set_iree_runtime_flags()
# set_iree_runtime_flags()
available_devices = []
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
# available_devices = []
# from shark.iree_utils.vulkan_utils import (
# get_all_vulkan_devices,
# )
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices
# vulkaninfo_list = get_all_vulkan_devices()
# vulkan_devices = []
# id = 0
# for device in vulkaninfo_list:
# vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
# id += 1
# if id != 0:
# print(f"vulkan devices are available.")
# available_devices.extend(vulkan_devices)
# metal_devices = get_devices_by_name("metal")
# available_devices.extend(metal_devices)
# cuda_devices = get_devices_by_name("cuda")
# available_devices.extend(cuda_devices)
# rocm_devices = get_devices_by_name("rocm")
# available_devices.extend(rocm_devices)
# cpu_device = get_devices_by_name("cpu-sync")
# available_devices.extend(cpu_device)
# cpu_device = get_devices_by_name("cpu-task")
# available_devices.extend(cpu_device)
# return available_devices
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# # Generate and return a new seed if the provided one is not in the
# # supported range (including -1)
# def sanitize_seed(seed: int | str):
# seed = int(seed)
# uint32_info = np.iinfo(np.uint32)
# uint32_min, uint32_max = uint32_info.min, uint32_info.max
# if seed < uint32_min or seed >= uint32_max:
# seed = randint(uint32_min, uint32_max)
# return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
# # take a seed expression in an input format and convert it to
# # a list of integers, where possible
# def parse_seed_input(seed_input: str | list | int):
# if isinstance(seed_input, str):
# try:
# seed_input = json.loads(seed_input)
# except (ValueError, TypeError):
# seed_input = None
if isinstance(seed_input, int):
return [seed_input]
# if isinstance(seed_input, int):
# return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
# if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
# return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
# raise TypeError(
# "Seed input must be an integer or an array of integers in JSON format"
# )

View File

@@ -594,6 +594,11 @@ p.add_argument(
##############################################################################
# Web UI flags
##############################################################################
p.add_argument(
"--default_config",
default="sdxl-turbo.json",
type=str,
)
p.add_argument(
"--webui",

View File

@@ -170,7 +170,7 @@ def webui():
css=dark_theme,
js=gradio_workarounds,
analytics_enabled=False,
title="Shark Studio 2.0 Beta",
title="Shark Studio 2.0",
) as studio_web:
amd_logo = Image.open(amdlogo_loc)
gr.Image(

View File

View File

@@ -1,22 +0,0 @@
import torch
from shark.parser import parser
from benchmarks.hf_transformer import SharkHFBenchmarkRunner
parser.add_argument(
"--model_name",
type=str,
required=True,
help='Specifies name of HF model to benchmark. (For exmaple "microsoft/MiniLM-L12-H384-uncased"',
)
load_args, unknown = parser.parse_known_args()
if __name__ == "__main__":
model_name = load_args.model_name
test_input = torch.randint(2, (1, 128))
shark_module = SharkHFBenchmarkRunner(
model_name, (test_input,), jit_trace=True
)
shark_module.benchmark_c()
shark_module.benchmark_python((test_input,))
shark_module.benchmark_torch(test_input)
shark_module.benchmark_onnx(test_input)

View File

@@ -1,181 +0,0 @@
import torch
from shark.shark_benchmark_runner import SharkBenchmarkRunner
from shark.parser import shark_args
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from onnxruntime.transformers.benchmark import (
run_pytorch,
run_tensorflow,
run_onnxruntime,
)
from onnxruntime.transformers.huggingface_models import MODELS
from onnxruntime.transformers.benchmark_helper import ConfigModifier, Precision
import os
import psutil
class OnnxFusionOptions(object):
def __init__(self):
self.disable_gelu = False
self.disable_layer_norm = False
self.disable_attention = False
self.disable_skip_layer_norm = False
self.disable_embed_layer_norm = False
self.disable_bias_skip_layer_norm = False
self.disable_bias_gelu = False
self.enable_gelu_approximation = False
self.use_mask_index = False
self.no_attention_mask = False
class HuggingFaceLanguage(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
hf_model_name, # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attentions weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
torchscript=True,
)
def forward(self, tokens):
return self.model.forward(tokens)[0]
class SharkHFBenchmarkRunner(SharkBenchmarkRunner):
# SharkRunner derived class with Benchmarking capabilities.
def __init__(
self,
model_name: str,
input: tuple,
dynamic: bool = False,
device: str = None,
jit_trace: bool = False,
from_aot: bool = False,
frontend: str = "torch",
):
self.device = device if device is not None else shark_args.device
if self.device == "gpu":
raise ValueError(
"Currently GPU Benchmarking is not supported due to OOM from ORT."
)
self.model_name = model_name
model = HuggingFaceLanguage(model_name)
SharkBenchmarkRunner.__init__(
self,
model,
input,
dynamic,
self.device,
jit_trace,
from_aot,
frontend,
)
def benchmark_torch(self, inputs):
use_gpu = self.device == "gpu"
# Set set the model's layer number to automatic.
config_modifier = ConfigModifier(None)
num_threads = psutil.cpu_count(logical=False)
batch_sizes = [inputs.shape[0]]
sequence_lengths = [inputs.shape[-1]]
cache_dir = os.path.join(".", "cache_models")
verbose = False
result = run_pytorch(
use_gpu,
[self.model_name],
None,
config_modifier,
Precision.FLOAT32,
num_threads,
batch_sizes,
sequence_lengths,
shark_args.num_iterations,
False,
cache_dir,
verbose,
)
print(
f"ONNX Pytorch-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
)
# TODO: Currently non-functional due to TF runtime error. There might be some issue with, initializing TF.
def benchmark_tf(self, inputs):
use_gpu = self.device == "gpu"
# Set set the model's layer number to automatic.
config_modifier = ConfigModifier(None)
num_threads = psutil.cpu_count(logical=False)
batch_sizes = [inputs.shape[0]]
sequence_lengths = [inputs.shape[-1]]
cache_dir = os.path.join(".", "cache_models")
verbose = False
result = run_tensorflow(
use_gpu,
[self.model_name],
None,
config_modifier,
Precision.FLOAT32,
num_threads,
batch_sizes,
sequence_lengths,
shark_args.num_iterations,
cache_dir,
verbose,
)
print(
f"ONNX TF-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
)
def benchmark_onnx(self, inputs):
if self.model_name not in MODELS:
print(
f"{self.model_name} is currently not supported in ORT's HF. Check \
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py \
for currently supported models. Exiting benchmark ONNX."
)
return
use_gpu = self.device == "gpu"
num_threads = psutil.cpu_count(logical=False)
batch_sizes = [inputs.shape[0]]
sequence_lengths = [inputs.shape[-1]]
cache_dir = os.path.join(".", "cache_models")
onnx_dir = os.path.join(".", "onnx_models")
verbose = False
input_counts = [1]
optimize_onnx = True
validate_onnx = False
disable_ort_io_binding = False
use_raw_attention_mask = True
model_fusion_statistics = {}
overwrite = False
model_source = "pt" # Either "pt" or "tf"
provider = None
config_modifier = ConfigModifier(None)
onnx_args = OnnxFusionOptions()
result = run_onnxruntime(
use_gpu,
provider,
[self.model_name],
None,
config_modifier,
Precision.FLOAT32,
num_threads,
batch_sizes,
sequence_lengths,
shark_args.num_iterations,
input_counts,
optimize_onnx,
validate_onnx,
cache_dir,
onnx_dir,
verbose,
overwrite,
disable_ort_io_binding,
use_raw_attention_mask,
model_fusion_statistics,
model_source,
onnx_args,
)
print(
f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
)

View File

@@ -1,231 +0,0 @@
from shark.shark_inference import SharkInference
from shark.iree_utils._common import check_device_drivers
import torch
import tensorflow as tf
import numpy as np
import torchvision.models as models
from transformers import (
AutoModelForSequenceClassification,
BertTokenizer,
TFBertModel,
)
import importlib
import pytest
import unittest
torch.manual_seed(0)
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
##################### Tensorflow Hugging Face LM Models ###################################
MAX_SEQUENCE_LENGTH = 512
BATCH_SIZE = 1
# Create a set of 2-dimensional inputs
tf_bert_input = [
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
]
class TFHuggingFaceLanguage(tf.Module):
def __init__(self, hf_model_name):
super(TFHuggingFaceLanguage, self).__init__()
# Create a BERT trainer with the created network.
self.m = TFBertModel.from_pretrained(hf_model_name, from_pt=True)
# Invoke the trainer model on the inputs. This causes the layer to be built.
self.m.predict = lambda x, y, z: self.m.call(
input_ids=x, attention_mask=y, token_type_ids=z, training=False
)
@tf.function(input_signature=tf_bert_input, jit_compile=True)
def forward(self, input_ids, attention_mask, token_type_ids):
return self.m.predict(input_ids, attention_mask, token_type_ids)
def get_TFhf_model(name):
model = TFHuggingFaceLanguage(name)
tokenizer = BertTokenizer.from_pretrained(name)
text = "Replace me by any text you'd like."
encoded_input = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
)
for key in encoded_input:
encoded_input[key] = tf.expand_dims(
tf.convert_to_tensor(encoded_input[key]), 0
)
test_input = (
encoded_input["input_ids"],
encoded_input["attention_mask"],
encoded_input["token_type_ids"],
)
actual_out = model.forward(*test_input)
return model, test_input, actual_out
##################### Hugging Face LM Models ###################################
class HuggingFaceLanguage(torch.nn.Module):
def __init__(self, hf_model_name):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
hf_model_name, # The pretrained model.
num_labels=2, # The number of output labels--2 for binary classification.
output_attentions=False, # Whether the model returns attentions weights.
output_hidden_states=False, # Whether the model returns all hidden-states.
torchscript=True,
)
def forward(self, tokens):
return self.model.forward(tokens)[0]
def get_hf_model(name):
model = HuggingFaceLanguage(name)
# TODO: Currently the test input is set to (1,128)
test_input = torch.randint(2, (1, 128))
actual_out = model(test_input)
return model, test_input, actual_out
################################################################################
##################### Torch Vision Models ###################################
class VisionModule(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.train(False)
def forward(self, input):
return self.model.forward(input)
def get_vision_model(torch_model):
model = VisionModule(torch_model)
# TODO: Currently the test input is set to (1,128)
test_input = torch.randn(1, 3, 224, 224)
actual_out = model(test_input)
return model, test_input, actual_out
############################# Benchmark Tests ####################################
pytest_benchmark_param = pytest.mark.parametrize(
("dynamic", "device"),
[
pytest.param(False, "cpu"),
# TODO: Language models are failing for dynamic case..
pytest.param(True, "cpu", marks=pytest.mark.skip),
pytest.param(
False,
"cuda",
marks=pytest.mark.skipif(
check_device_drivers("cuda"), reason="nvidia-smi not found"
),
),
pytest.param(True, "cuda", marks=pytest.mark.skip),
pytest.param(
False,
"vulkan",
marks=pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
),
),
pytest.param(
True,
"vulkan",
marks=pytest.mark.skipif(
check_device_drivers("vulkan"),
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
),
),
],
)
@pytest.mark.skipif(
importlib.util.find_spec("iree.tools") is None,
reason="Cannot find tools to import TF",
)
@pytest_benchmark_param
def test_bench_minilm_torch(dynamic, device):
model, test_input, act_out = get_hf_model(
"microsoft/MiniLM-L12-H384-uncased"
)
shark_module = SharkInference(
model,
(test_input,),
device=device,
dynamic=dynamic,
jit_trace=True,
benchmark_mode=True,
)
try:
# If becnhmarking succesful, assert success/True.
shark_module.compile()
shark_module.benchmark_all((test_input,))
assert True
except Exception as e:
# If anything happen during benchmarking, assert False/failure.
assert False
@pytest.mark.skipif(
importlib.util.find_spec("iree.tools") is None,
reason="Cannot find tools to import TF",
)
@pytest_benchmark_param
def test_bench_distilbert(dynamic, device):
model, test_input, act_out = get_TFhf_model("distilbert-base-uncased")
shark_module = SharkInference(
model,
test_input,
device=device,
dynamic=dynamic,
jit_trace=True,
benchmark_mode=True,
)
try:
# If becnhmarking succesful, assert success/True.
shark_module.set_frontend("tensorflow")
shark_module.compile()
shark_module.benchmark_all(test_input)
assert True
except Exception as e:
# If anything happen during benchmarking, assert False/failure.
assert False
@pytest.mark.skip(reason="XLM Roberta too large to test.")
@pytest_benchmark_param
def test_bench_xlm_roberta(dynamic, device):
model, test_input, act_out = get_TFhf_model("xlm-roberta-base")
shark_module = SharkInference(
model,
test_input,
device=device,
dynamic=dynamic,
jit_trace=True,
benchmark_mode=True,
)
try:
# If becnhmarking succesful, assert success/True.
shark_module.set_frontend("tensorflow")
shark_module.compile()
shark_module.benchmark_all(test_input)
assert True
except Exception as e:
# If anything happen during benchmarking, assert False/failure.
assert False

View File

@@ -1,45 +0,0 @@
import torch
from benchmarks.hf_transformer import SharkHFBenchmarkRunner
import importlib
import pytest
torch.manual_seed(0)
############################# HF Benchmark Tests ####################################
# Test running benchmark module without failing.
pytest_benchmark_param = pytest.mark.parametrize(
("dynamic", "device"),
[
pytest.param(False, "cpu"),
# TODO: Language models are failing for dynamic case..
pytest.param(True, "cpu", marks=pytest.mark.skip),
],
)
@pytest.mark.skipif(
importlib.util.find_spec("onnxruntime") is None,
reason="Cannot find ONNXRUNTIME.",
)
@pytest_benchmark_param
def test_HFbench_minilm_torch(dynamic, device):
model_name = "bert-base-uncased"
test_input = torch.randint(2, (1, 128))
try:
shark_module = SharkHFBenchmarkRunner(
model_name,
(test_input,),
jit_trace=True,
dynamic=dynamic,
device=device,
)
shark_module.benchmark_c()
shark_module.benchmark_python((test_input,))
shark_module.benchmark_torch(test_input)
shark_module.benchmark_onnx(test_input)
# If becnhmarking succesful, assert success/True.
assert True
except Exception as e:
# If anything happen during benchmarking, assert False/failure.
assert False

3
cpp/.gitignore vendored
View File

@@ -1,3 +0,0 @@
*.mlir
*.vmfb
*.ini

View File

@@ -1,52 +0,0 @@
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
cmake_minimum_required(VERSION 3.21...3.23)
#-------------------------------------------------------------------------------
# Project configuration
#-------------------------------------------------------------------------------
project(iree-samples C CXX)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
#-------------------------------------------------------------------------------
# Core project dependency
#-------------------------------------------------------------------------------
message(STATUS "Fetching core IREE repo (this may take a few minutes)...")
# Note: for log output, set -DFETCHCONTENT_QUIET=OFF,
# see https://gitlab.kitware.com/cmake/cmake/-/issues/18238#note_440475
include(FetchContent)
FetchContent_Declare(
iree
GIT_REPOSITORY https://github.com/nod-ai/srt.git
GIT_TAG shark
GIT_SUBMODULES_RECURSE OFF
GIT_SHALLOW OFF
GIT_PROGRESS ON
USES_TERMINAL_DOWNLOAD ON
)
# Extend module path to find MLIR CMake modules.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/lib/cmake/mlir")
# Disable core project features not needed for these out of tree samples.
set(IREE_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(IREE_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(iree)
FetchContent_GetProperties(iree SOURCE_DIR IREE_SOURCE_DIR)
#-------------------------------------------------------------------------------
# Individual samples
#-------------------------------------------------------------------------------
add_subdirectory(vulkan_gui)

View File

@@ -1,82 +0,0 @@
# SHARK C/C++ Samples
These C/C++ samples can be built using CMake. The samples depend on the main
SHARK-Runtime project's C/C++ sources, including both the runtime and the compiler.
Individual samples may require additional dependencies. Watch CMake's output
for information about which you are missing for individual samples.
On Windows we recommend using https://github.com/microsoft/vcpkg to download packages for
your system. The general setup flow looks like
*Install and activate SHARK*
```bash
source shark.venv/bin/activate #follow main repo instructions to setup your venv
```
*Install Dependencies*
```bash
vcpkg install [library] --triplet [your platform]
vcpkg integrate install
# Then pass `-DCMAKE_TOOLCHAIN_FILE=[check logs for path]` when configuring CMake
```
In Ubuntu Linux you can install
```bash
sudo apt install libsdl2-dev
```
*Build*
```bash
cd cpp
cmake -GNinja -B build/
cmake --build build/
```
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
```bash
python save_img.py
```
Note that this requires tensorflow, e.g.
```bash
python -m pip install tensorflow
```
*Run the vulkan_gui*
```bash
./build/vulkan_gui/iree-samples-resnet-vulkan-gui
```
## Other models
A tool for benchmarking other models is built and can be invoked with a command like the following
```bash
./build/vulkan_gui/iree-vulkan-gui --module-file=path/to/.vmfb --function_input=...
```
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

Binary file not shown.

Before

Width:  |  Height:  |  Size: 26 KiB

View File

@@ -1,18 +0,0 @@
import numpy as np
import tensorflow as tf
from shark.shark_inference import SharkInference
def load_and_preprocess_image(fname: str):
image = tf.io.read_file(fname)
image = tf.image.decode_image(image, channels=3)
image = tf.image.resize(image, (224, 224))
image = image[tf.newaxis, :]
# preprocessing pipeline
input_tensor = tf.keras.applications.resnet50.preprocess_input(image)
return input_tensor
data = load_and_preprocess_image("dog_imagenet.jpg").numpy()
data.tofile("dog.bin")

View File

@@ -1,84 +0,0 @@
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
if(NOT IREE_TARGET_BACKEND_LLVM_CPU OR
NOT IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
message(STATUS "Missing LLVM backend and/or embeddded elf loader, skipping vision_inference sample")
return()
endif()
# vcpkg install stb
# tested with version 2021-09-10
find_package(Stb)
if(NOT Stb_FOUND)
message(STATUS "Could not find Stb, skipping vision inference sample")
return()
endif()
# Compile mnist.mlir to mnist.vmfb.
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
set(_COMPILE_ARGS)
list(APPEND _COMPILE_ARGS "--iree-input-type=auto")
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
list(APPEND _COMPILE_ARGS "-o")
list(APPEND _COMPILE_ARGS "mnist.vmfb")
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
COMMAND ${_COMPILE_TOOL_EXECUTABLE} ${_COMPILE_ARGS}
DEPENDS ${_COMPILE_TOOL_EXECUTABLE} "${IREE_SOURCE_DIR}/samples/models/mnist.mlir"
)
# Embed mnist.vmfb into a C file as mnist_bytecode_module_c.[h/c]
set(_EMBED_DATA_EXECUTABLE $<TARGET_FILE:generate_embed_data>)
set(_EMBED_ARGS)
list(APPEND _EMBED_ARGS "--output_header=mnist_bytecode_module_c.h")
list(APPEND _EMBED_ARGS "--output_impl=mnist_bytecode_module_c.c")
list(APPEND _EMBED_ARGS "--identifier=iree_samples_vision_inference_mnist_bytecode_module")
list(APPEND _EMBED_ARGS "--flatten")
list(APPEND _EMBED_ARGS "${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb")
add_custom_command(
OUTPUT "mnist_bytecode_module_c.h" "mnist_bytecode_module_c.c"
COMMAND ${_EMBED_DATA_EXECUTABLE} ${_EMBED_ARGS}
DEPENDS ${_EMBED_DATA_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
)
# Define a library target for mnist_bytecode_module_c.
add_library(iree_samples_vision_inference_mnist_bytecode_module_c OBJECT)
target_sources(iree_samples_vision_inference_mnist_bytecode_module_c
PRIVATE
mnist_bytecode_module_c.h
mnist_bytecode_module_c.c
)
# Define the sample executable.
set(_NAME "iree-run-mnist-module")
add_executable(${_NAME} "")
target_sources(${_NAME}
PRIVATE
"image_util.h"
"image_util.c"
"iree-run-mnist-module.c"
)
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-run-mnist-module")
target_include_directories(${_NAME} PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
)
target_include_directories(${_NAME} PRIVATE
${Stb_INCLUDE_DIR}
)
target_link_libraries(${_NAME}
iree_base_base
iree_base_tracing
iree_hal_hal
iree_runtime_runtime
iree_samples_vision_inference_mnist_bytecode_module_c
)
# Define a target that copies the test image into the build directory.
add_custom_target(iree_samples_vision_inference_test_image
COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/mnist_test.png" "${CMAKE_CURRENT_BINARY_DIR}/mnist_test.png")
add_dependencies(${_NAME} iree_samples_vision_inference_test_image)
message(STATUS "Configured vision_inference sample successfully")

View File

@@ -1,8 +0,0 @@
# Vision Inference Sample (C code)
This sample demonstrates how to run a MNIST handwritten digit detection vision
model on an image using IREE's C API.
A similar sample is implemented using a Python script and IREE's command line
tools over in the primary iree repository at
https://github.com/iree-org/iree/tree/main/samples/vision_inference

View File

@@ -1,224 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "image_util.h"
#include <math.h>
#include "iree/base/internal/flags.h"
#include "iree/base/tracing.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
const uint8_t* pixel_data, iree_host_size_t buffer_length,
const float* input_range, iree_host_size_t range_length,
float* out_buffer) {
IREE_TRACE_ZONE_BEGIN(z0);
if (range_length != 2) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"range defined as 2-element [min, max] array.");
}
float input_scale = fabsf(input_range[1] - input_range[0]) / 2.0f;
float input_offset = (input_range[0] + input_range[1]) / 2.0f;
const float kUint8Mean = 127.5f;
for (int i = 0; i < buffer_length; ++i) {
out_buffer[i] =
(((float)(pixel_data[i])) - kUint8Mean) / kUint8Mean * input_scale +
input_offset;
}
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
}
iree_status_t iree_tools_utils_load_pixel_data_impl(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
int img_dims[3];
if (stbi_info(filename.data, img_dims, &(img_dims[1]), &(img_dims[2])) == 0) {
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
(int)filename.size, filename.data);
}
if (!(element_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 ||
element_type == IREE_HAL_ELEMENT_TYPE_SINT_8 ||
element_type == IREE_HAL_ELEMENT_TYPE_UINT_8)) {
char element_type_str[16];
IREE_RETURN_IF_ERROR(iree_hal_format_element_type(
element_type, sizeof(element_type_str), element_type_str, NULL));
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"element type %s not supported", element_type_str);
}
switch (shape_rank) {
case 2: { // Assume tensor <height x width>
if (img_dims[2] != 1 || (shape[0] != img_dims[1]) ||
(shape[1] != img_dims[0])) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"image size: %dx%dx%d, expected: %" PRIdim "x%" PRIdim, img_dims[0],
img_dims[1], img_dims[2], shape[1], shape[0]);
}
break;
}
case 3: { // Assume tensor <height x width x channel>
if (shape[0] != img_dims[1] || shape[1] != img_dims[0] ||
shape[2] != img_dims[2]) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"image size: %dx%dx%d, expected: %" PRIdim
"x%" PRIdim "x%" PRIdim,
img_dims[0], img_dims[1], img_dims[2], shape[1],
shape[0], shape[2]);
}
break;
}
case 4: { // Assume tensor <batch x height x width x channel>
if (shape[1] != img_dims[1] || shape[2] != img_dims[0] ||
shape[3] != img_dims[2]) {
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"image size: %dx%dx%d, expected: %" PRIdim
"x%" PRIdim "x%" PRIdim,
img_dims[0], img_dims[1], img_dims[2], shape[2],
shape[1], shape[3]);
}
break;
}
default:
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"Input buffer shape rank %" PRIhsz " not supported", shape_rank);
}
// Drop the alpha channel if present.
int req_ch = (img_dims[2] >= 3) ? 3 : 0;
*out_pixel_data = stbi_load(filename.data, img_dims, &(img_dims[1]),
&(img_dims[2]), req_ch);
if (*out_pixel_data == NULL) {
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
(int)filename.size, filename.data);
}
*out_buffer_length =
img_dims[0] * img_dims[1] * (img_dims[2] > 3 ? 3 : img_dims[2]);
return iree_ok_status();
}
iree_status_t iree_tools_utils_load_pixel_data(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t result = iree_tools_utils_load_pixel_data_impl(
filename, shape, shape_rank, element_type, out_pixel_data,
out_buffer_length);
IREE_TRACE_ZONE_END(z0);
return result;
}
iree_status_t iree_tools_utils_buffer_view_from_image(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view) {
IREE_TRACE_ZONE_BEGIN(z0);
*out_buffer_view = NULL;
if (element_type != IREE_HAL_ELEMENT_TYPE_SINT_8 &&
element_type != IREE_HAL_ELEMENT_TYPE_UINT_8) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"element type should be i8 or u8");
}
iree_status_t result;
uint8_t* pixel_data = NULL;
iree_host_size_t buffer_length;
result = iree_tools_utils_load_pixel_data(
filename, shape, shape_rank, element_type, &pixel_data, &buffer_length);
if (iree_status_is_ok(result)) {
iree_host_size_t element_byte =
iree_hal_element_dense_byte_count(element_type);
// SINT_8 and UINT_8 perform direct buffer wrap.
result = iree_hal_buffer_view_allocate_buffer(
allocator, shape_rank, shape, element_type,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
.access = IREE_HAL_MEMORY_ACCESS_READ,
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
IREE_HAL_BUFFER_USAGE_TRANSFER,
},
iree_make_const_byte_span(pixel_data, element_byte * buffer_length),
out_buffer_view);
}
stbi_image_free(pixel_data);
IREE_TRACE_ZONE_END(z0);
return result;
}
typedef struct iree_tools_utils_buffer_view_load_params_t {
const uint8_t* pixel_data;
iree_host_size_t pixel_data_length;
const float* input_range;
iree_host_size_t input_range_length;
} iree_tools_utils_buffer_view_load_params_t;
static iree_status_t iree_tools_utils_buffer_view_load_image_rescaled(
iree_hal_buffer_mapping_t* mapping, void* user_data) {
iree_tools_utils_buffer_view_load_params_t* params =
(iree_tools_utils_buffer_view_load_params_t*)user_data;
return iree_tools_utils_pixel_rescaled_to_buffer(
params->pixel_data, params->pixel_data_length, params->input_range,
params->input_range_length, (float*)mapping->contents.data);
}
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
iree_hal_allocator_t* allocator, const float* input_range,
iree_host_size_t input_range_length,
iree_hal_buffer_view_t** out_buffer_view) {
IREE_TRACE_ZONE_BEGIN(z0);
*out_buffer_view = NULL;
if (element_type != IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
IREE_TRACE_ZONE_END(z0);
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"element type should be f32");
}
// Classic row-major image layout.
iree_hal_encoding_type_t encoding_type =
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
// Load pixel data from the file into a new host memory allocation (the only
// interface stb_image provides). A real application would want to use the
// generation callback to directly decode the image into the target mapped
// device buffer.
uint8_t* pixel_data = NULL;
iree_host_size_t buffer_length = 0;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_tools_utils_load_pixel_data(filename, shape, shape_rank,
element_type, &pixel_data,
&buffer_length));
iree_tools_utils_buffer_view_load_params_t params = {
.pixel_data = pixel_data,
.pixel_data_length = buffer_length,
.input_range = input_range,
.input_range_length = input_range_length,
};
iree_status_t status = iree_hal_buffer_view_generate_buffer(
allocator, shape_rank, shape, element_type, encoding_type,
(iree_hal_buffer_params_t){
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
IREE_HAL_BUFFER_USAGE_TRANSFER |
IREE_HAL_BUFFER_USAGE_MAPPING,
},
iree_tools_utils_buffer_view_load_image_rescaled, &params,
out_buffer_view);
stbi_image_free(pixel_data);
IREE_TRACE_ZONE_END(z0);
return status;
}

View File

@@ -1,77 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
#define IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/buffer_view.h"
#if __cplusplus
extern "C" {
#endif // __cplusplus
// Loads the image at |filename| into |out_pixel_data| and sets
// |out_buffer_length| to its length.
//
// The image dimension must match the width, height, and channel in|shape|,
// while 2 <= |shape_rank| <= 4 to match the image tensor format.
//
// The file must be in a format supported by stb_image.h.
// The returned |out_pixel_data| buffer must be released by the caller.
iree_status_t iree_tools_utils_load_pixel_data(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length);
// Parse the content in an image file in |filename| into a HAL buffer view
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
//
// The |element_type| has to be SINT_8 or UINT_8. For FLOAT_32, use
// |iree_tools_utils_buffer_view_from_image_rescaled| instead.
//
// The returned |out_buffer_view| must be released by the caller.
iree_status_t iree_tools_utils_buffer_view_from_image(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view);
// Parse the content in an image file in |filename| into a HAL buffer view
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
// The value in |out_buffer_view| is rescaled with |input_range|.
//
// The |element_type| has to be FLOAT_32, For SINT_8 or UINT_8, use
// |iree_tools_utils_buffer_view_from_image| instead.
//
// The returned |out_buffer_view| must be released by the caller.
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
const iree_string_view_t filename, const iree_hal_dim_t* shape,
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
iree_hal_allocator_t* allocator, const float* input_range,
iree_host_size_t input_range_length,
iree_hal_buffer_view_t** out_buffer_view);
// Normalize uint8_t |pixel_data| of the size |buffer_length| to float buffer
// |out_buffer| with the range |input_range|.
//
// float32_x = (uint8_x - 127.5) / 127.5 * input_scale + input_offset, where
// input_scale = abs(|input_range[0]| - |input_range[1]| / 2
// input_offset = |input_range[0]| + |input_range[1]| / 2
//
// |out_buffer| needs to be allocated before the call.
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
const uint8_t* pixel_data, iree_host_size_t pixel_count,
const float* input_range, iree_host_size_t input_range_length,
float* out_buffer);
#if __cplusplus
}
#endif // __cplusplus
#endif // IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_

View File

@@ -1,121 +0,0 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// This sample uses image_util to load a hand-written image as an
// iree_hal_buffer_view_t then passes it to the bytecode module built from
// mnist.mlir on the CPU backend with the local-task driver.
#include <float.h>
#include "image_util.h"
#include "iree/runtime/api.h"
#include "mnist_bytecode_module_c.h"
iree_status_t Run(const iree_string_view_t image_path) {
iree_runtime_instance_options_t instance_options;
iree_runtime_instance_options_initialize(IREE_API_VERSION_LATEST,
&instance_options);
iree_runtime_instance_options_use_all_available_drivers(&instance_options);
iree_runtime_instance_t* instance = NULL;
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
&instance_options, iree_allocator_system(), &instance));
// TODO(#5724): move device selection into the compiled modules.
iree_hal_device_t* device = NULL;
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
instance, iree_make_cstring_view("local-task"), &device));
// Create one session per loaded module to hold the module state.
iree_runtime_session_options_t session_options;
iree_runtime_session_options_initialize(&session_options);
iree_runtime_session_t* session = NULL;
IREE_RETURN_IF_ERROR(iree_runtime_session_create_with_device(
instance, &session_options, device,
iree_runtime_instance_host_allocator(instance), &session));
iree_hal_device_release(device);
const struct iree_file_toc_t* module_file =
iree_samples_vision_inference_mnist_bytecode_module_create();
IREE_RETURN_IF_ERROR(iree_runtime_session_append_bytecode_module_from_memory(
session, iree_make_const_byte_span(module_file->data, module_file->size),
iree_allocator_null()));
iree_runtime_call_t call;
IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
session, iree_make_cstring_view("module.predict"), &call));
// Prepare the input hal buffer view with image_util library.
// The input of the mmist model is single 28x28 pixel image as a
// tensor<1x28x28x1xf32>, with pixels in [0.0, 1.0].
iree_hal_buffer_view_t* buffer_view = NULL;
iree_hal_dim_t buffer_shape[] = {1, 28, 28, 1};
iree_hal_element_type_t hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32;
float input_range[2] = {0.0f, 1.0f};
IREE_RETURN_IF_ERROR(
iree_tools_utils_buffer_view_from_image_rescaled(
image_path, buffer_shape, IREE_ARRAYSIZE(buffer_shape),
hal_element_type, iree_hal_device_allocator(device), input_range,
IREE_ARRAYSIZE(input_range), &buffer_view),
"load image");
IREE_RETURN_IF_ERROR(
iree_runtime_call_inputs_push_back_buffer_view(&call, buffer_view));
iree_hal_buffer_view_release(buffer_view);
IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0));
// Get the result buffers from the invocation.
iree_hal_buffer_view_t* ret_buffer_view = NULL;
IREE_RETURN_IF_ERROR(
iree_runtime_call_outputs_pop_front_buffer_view(&call, &ret_buffer_view));
// Read back the results. The output of the mnist model is a 1x10 prediction
// confidence values for each digit in [0, 9].
float predictions[1 * 10] = {0.0f};
IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
iree_hal_buffer_view_buffer(ret_buffer_view), 0, predictions,
sizeof(predictions), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout()));
iree_hal_buffer_view_release(ret_buffer_view);
// Get the highest index from the output.
float result_val = FLT_MIN;
int result_idx = 0;
for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(predictions); ++i) {
if (predictions[i] > result_val) {
result_val = predictions[i];
result_idx = i;
}
}
fprintf(stdout, "Detected number: %d\n", result_idx);
iree_runtime_call_deinitialize(&call);
iree_runtime_session_release(session);
iree_runtime_instance_release(instance);
return iree_ok_status();
}
int main(int argc, char** argv) {
if (argc > 2) {
fprintf(stderr, "Usage: iree-run-mnist-module <image file>\n");
return -1;
}
iree_string_view_t image_path;
if (argc == 1) {
image_path = iree_make_cstring_view("mnist_test.png");
} else {
image_path = iree_make_cstring_view(argv[1]);
}
iree_status_t result = Run(image_path);
if (!iree_status_is_ok(result)) {
iree_status_fprint(stderr, result);
iree_status_ignore(result);
return -1;
}
iree_status_ignore(result);
return 0;
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 261 B

View File

@@ -1,116 +0,0 @@
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV OR
NOT IREE_HAL_DRIVER_VULKAN)
message(STATUS "Missing Vulkan backend and/or driver, skipping vulkan_gui sample")
return()
endif()
# This target statically links against Vulkan.
# One way to achieve this is by installing the Vulkan SDK from
# https://vulkan.lunarg.com/.
include(FindVulkan)
if(NOT Vulkan_FOUND)
message(STATUS "Could not find Vulkan, skipping vulkan_gui sample")
return()
endif()
# vcpkg install sdl2[vulkan]
# tested with versions 2.0.14#4 - 2.0.22#1
find_package(SDL2)
if(NOT SDL2_FOUND)
message(STATUS "Could not find SDL2, skipping vulkan_gui sample")
return()
endif()
FetchContent_Declare(
imgui
GIT_REPOSITORY https://github.com/ocornut/imgui
GIT_TAG master
)
FetchContent_MakeAvailable(imgui)
# Dear ImGui
set(IMGUI_DIR ${CMAKE_BINARY_DIR}/_deps/imgui-src)
message("Looking for Imgui in ${IMGUI_DIR}")
include_directories(${IMGUI_DIR} ${IMGUI_DIR}/backends ..)
function(iree_vulkan_sample)
cmake_parse_arguments(
_RULE
""
"NAME"
"SRCS"
${ARGN}
)
# Define the sample executable.
set(_NAME "${_RULE_NAME}")
set(SRCS "${_RULE_SRCS}")
add_executable(${_NAME} "")
target_sources(${_NAME}
PRIVATE
${SRCS}
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
"${IMGUI_DIR}/imgui.cpp"
"${IMGUI_DIR}/imgui_draw.cpp"
"${IMGUI_DIR}/imgui_demo.cpp"
"${IMGUI_DIR}/imgui_tables.cpp"
"${IMGUI_DIR}/imgui_widgets.cpp"
)
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_NAME}")
target_include_directories(${_NAME} PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
)
target_link_libraries(${_NAME}
SDL2::SDL2
Vulkan::Vulkan
iree_runtime_runtime
iree_base_internal_main
iree_hal_drivers_vulkan_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
iree_vm_cc
iree_tooling_vm_util_cc
iree_tooling_context_util
)
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
else()
set(_GUI_LINKOPTS "")
endif()
target_link_options(${_NAME}
PRIVATE
${_GUI_LINKOPTS}
)
endfunction()
iree_vulkan_sample(
NAME
iree-samples-resnet-vulkan-gui
SRCS
vulkan_resnet_inference_gui.cc
)
iree_vulkan_sample(
NAME
iree-vulkan-gui
SRCS
vulkan_inference_gui.cc
)
message(STATUS "Configured vulkan_gui sample successfully")

View File

@@ -1,4 +0,0 @@
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = "arith.mulf"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

File diff suppressed because it is too large Load Diff

View File

@@ -1,957 +0,0 @@
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Vulkan Graphics + IREE API Integration Sample.
#include <SDL.h>
#include <SDL_vulkan.h>
#include <imgui.h>
#include <imgui_impl_sdl.h>
#include <imgui_impl_vulkan.h>
#include <vulkan/vulkan.h>
#include <cstring>
#include <set>
#include <vector>
#include <fstream>
#include <array>
#include <cstdio>
#include <cstdlib>
#include <iterator>
#include <string>
#include <utility>
#include "iree/hal/drivers/vulkan/api.h"
// IREE's C API:
#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/drivers/vulkan/registration/driver_module.h"
#include "iree/modules/hal/module.h"
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "iree/vm/ref_cc.h"
// iree-run-module
#include "iree/base/internal/flags.h"
#include "iree/base/status_cc.h"
#include "iree/base/tracing.h"
#include "iree/modules/hal/types.h"
#include "iree/tooling/comparison.h"
#include "iree/tooling/context_util.h"
#include "iree/tooling/vm_util_cc.h"
// Other dependencies (helpers, etc.)
#include "iree/base/internal/main.h"
#define IMGUI_UNLIMITED_FRAME_RATE
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
IREE_FLAG(string, entry_function, "",
"Name of a function contained in the module specified by module_file "
"to run.");
// TODO(benvanik): move --function_input= flag into a util.
static iree_status_t parse_function_io(iree_string_view_t flag_name,
void* storage,
iree_string_view_t value) {
auto* list = (std::vector<std::string>*)storage;
list->push_back(std::string(value.data, value.size));
return iree_ok_status();
}
static void print_function_io(iree_string_view_t flag_name, void* storage,
FILE* file) {
auto* list = (std::vector<std::string>*)storage;
if (list->empty()) {
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
} else {
for (size_t i = 0; i < list->size(); ++i) {
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
list->at(i).c_str());
}
}
}
static std::vector<std::string> FLAG_function_inputs;
IREE_FLAG_CALLBACK(
parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
"An input (a) value or (b) buffer of the format:\n"
" (a) scalar value\n"
" value\n"
" e.g.: --function_input=\"3.14\"\n"
" (b) buffer:\n"
" [shape]xtype=[value]\n"
" e.g.: --function_input=\"2x2xi32=1 2 3 4\"\n"
"Optionally, brackets may be used to separate the element values:\n"
" 2x2xi32=[[1 2][3 4]]\n"
"Raw binary files can be read to provide buffer contents:\n"
" 2x2xi32=@some/file.bin\n"
"numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
" @some.npy\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");
typedef struct iree_file_toc_t {
const char* name; // the file's original name
char* data; // beginning of the file
size_t size; // length of the file
} iree_file_toc_t;
bool load_file(const char* filename, char** pOut, size_t* pSize)
{
FILE* f = fopen(filename, "rb");
if (f == NULL)
{
fprintf(stderr, "Can't open %s\n", filename);
return false;
}
fseek(f, 0L, SEEK_END);
*pSize = ftell(f);
fseek(f, 0L, SEEK_SET);
*pOut = (char*)malloc(*pSize);
size_t size = fread(*pOut, *pSize, 1, f);
fclose(f);
return size != 0;
}
static VkAllocationCallbacks* g_Allocator = NULL;
static VkInstance g_Instance = VK_NULL_HANDLE;
static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE;
static VkDevice g_Device = VK_NULL_HANDLE;
static uint32_t g_QueueFamily = (uint32_t)-1;
static VkQueue g_Queue = VK_NULL_HANDLE;
static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE;
static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE;
static ImGui_ImplVulkanH_Window g_MainWindowData;
static uint32_t g_MinImageCount = 2;
static bool g_SwapChainRebuild = false;
static int g_SwapChainResizeWidth = 0;
static int g_SwapChainResizeHeight = 0;
static void check_vk_result(VkResult err) {
if (err == 0) return;
fprintf(stderr, "VkResult: %d\n", err);
abort();
}
// Returns the names of the Vulkan layers used for the given IREE
// |extensibility_set| and |features|.
std::vector<const char*> GetIreeLayers(
iree_hal_vulkan_extensibility_set_t extensibility_set,
iree_hal_vulkan_features_t features) {
iree_host_size_t required_count;
iree_hal_vulkan_query_extensibility_set(
features, extensibility_set, /*string_capacity=*/0, &required_count,
/*out_string_values=*/NULL);
std::vector<const char*> layers(required_count);
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
layers.size(), &required_count,
layers.data());
return layers;
}
// Returns the names of the Vulkan extensions used for the given IREE
// |extensibility_set| and |features|.
std::vector<const char*> GetIreeExtensions(
iree_hal_vulkan_extensibility_set_t extensibility_set,
iree_hal_vulkan_features_t features) {
iree_host_size_t required_count;
iree_hal_vulkan_query_extensibility_set(
features, extensibility_set, /*string_capacity=*/0, &required_count,
/*out_string_values=*/NULL);
std::vector<const char*> extensions(required_count);
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
extensions.size(), &required_count,
extensions.data());
return extensions;
}
// Returns the names of the Vulkan extensions used for the given IREE
// |vulkan_features|.
std::vector<const char*> GetDeviceExtensions(
VkPhysicalDevice physical_device,
iree_hal_vulkan_features_t vulkan_features) {
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED,
vulkan_features);
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
vulkan_features);
uint32_t extension_count = 0;
check_vk_result(vkEnumerateDeviceExtensionProperties(
physical_device, nullptr, &extension_count, nullptr));
std::vector<VkExtensionProperties> extension_properties(extension_count);
check_vk_result(vkEnumerateDeviceExtensionProperties(
physical_device, nullptr, &extension_count, extension_properties.data()));
// Merge extensions lists, including optional and required for simplicity.
std::set<const char*> ext_set;
ext_set.insert("VK_KHR_swapchain");
ext_set.insert(iree_required_extensions.begin(),
iree_required_extensions.end());
for (int i = 0; i < iree_optional_extensions.size(); ++i) {
const char* optional_extension = iree_optional_extensions[i];
for (int j = 0; j < extension_count; ++j) {
if (strcmp(optional_extension, extension_properties[j].extensionName) ==
0) {
ext_set.insert(optional_extension);
break;
}
}
}
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
return extensions;
}
std::vector<const char*> GetInstanceLayers(
iree_hal_vulkan_features_t vulkan_features) {
// Query the layers that IREE wants / needs.
std::vector<const char*> required_layers = GetIreeLayers(
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED, vulkan_features);
std::vector<const char*> optional_layers = GetIreeLayers(
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, vulkan_features);
// Query the layers that are available on the Vulkan ICD.
uint32_t layer_property_count = 0;
check_vk_result(
vkEnumerateInstanceLayerProperties(&layer_property_count, NULL));
std::vector<VkLayerProperties> layer_properties(layer_property_count);
check_vk_result(vkEnumerateInstanceLayerProperties(&layer_property_count,
layer_properties.data()));
// Match between optional/required and available layers.
std::vector<const char*> layers;
for (const char* layer_name : required_layers) {
bool found = false;
for (const auto& layer_property : layer_properties) {
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
found = true;
layers.push_back(layer_name);
break;
}
}
if (!found) {
fprintf(stderr, "Required layer %s not available\n", layer_name);
abort();
}
}
for (const char* layer_name : optional_layers) {
for (const auto& layer_property : layer_properties) {
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
layers.push_back(layer_name);
break;
}
}
}
return layers;
}
std::vector<const char*> GetInstanceExtensions(
SDL_Window* window, iree_hal_vulkan_features_t vulkan_features) {
// Ask SDL for its list of required instance extensions.
uint32_t sdl_extensions_count = 0;
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count, NULL);
std::vector<const char*> sdl_extensions(sdl_extensions_count);
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count,
sdl_extensions.data());
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED,
vulkan_features);
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL,
vulkan_features);
// Merge extensions lists, including optional and required for simplicity.
std::set<const char*> ext_set;
ext_set.insert(sdl_extensions.begin(), sdl_extensions.end());
ext_set.insert(iree_required_extensions.begin(),
iree_required_extensions.end());
ext_set.insert(iree_optional_extensions.begin(),
iree_optional_extensions.end());
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
return extensions;
}
void SetupVulkan(iree_hal_vulkan_features_t vulkan_features,
const char** instance_layers, uint32_t instance_layers_count,
const char** instance_extensions,
uint32_t instance_extensions_count,
const VkAllocationCallbacks* allocator, VkInstance* instance,
uint32_t* queue_family_index,
VkPhysicalDevice* physical_device, VkQueue* queue,
VkDevice* device, VkDescriptorPool* descriptor_pool) {
VkResult err;
// Create Vulkan Instance
{
VkInstanceCreateInfo create_info = {};
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
create_info.enabledLayerCount = instance_layers_count;
create_info.ppEnabledLayerNames = instance_layers;
create_info.enabledExtensionCount = instance_extensions_count;
create_info.ppEnabledExtensionNames = instance_extensions;
err = vkCreateInstance(&create_info, allocator, instance);
check_vk_result(err);
}
// Select GPU
{
uint32_t gpu_count;
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, NULL);
check_vk_result(err);
IM_ASSERT(gpu_count > 0);
VkPhysicalDevice* gpus =
(VkPhysicalDevice*)malloc(sizeof(VkPhysicalDevice) * gpu_count);
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, gpus);
check_vk_result(err);
// Use the first reported GPU for simplicity.
*physical_device = gpus[0];
VkPhysicalDeviceProperties properties;
vkGetPhysicalDeviceProperties(*physical_device, &properties);
fprintf(stdout, "Selected Vulkan device: '%s'\n", properties.deviceName);
free(gpus);
}
// Select queue family. We want a single queue with graphics and compute for
// simplicity, but we could also discover and use separate queues for each.
{
uint32_t count;
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, NULL);
VkQueueFamilyProperties* queues = (VkQueueFamilyProperties*)malloc(
sizeof(VkQueueFamilyProperties) * count);
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, queues);
for (uint32_t i = 0; i < count; i++) {
if (queues[i].queueFlags &
(VK_QUEUE_GRAPHICS_BIT | VK_QUEUE_COMPUTE_BIT)) {
*queue_family_index = i;
break;
}
}
free(queues);
IM_ASSERT(*queue_family_index != (uint32_t)-1);
}
// Create Logical Device (with 1 queue)
{
std::vector<const char*> device_extensions =
GetDeviceExtensions(*physical_device, vulkan_features);
const float queue_priority[] = {1.0f};
VkDeviceQueueCreateInfo queue_info = {};
queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
queue_info.queueFamilyIndex = *queue_family_index;
queue_info.queueCount = 1;
queue_info.pQueuePriorities = queue_priority;
VkDeviceCreateInfo create_info = {};
create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
create_info.queueCreateInfoCount = 1;
create_info.pQueueCreateInfos = &queue_info;
create_info.enabledExtensionCount =
static_cast<uint32_t>(device_extensions.size());
create_info.ppEnabledExtensionNames = device_extensions.data();
// Enable timeline semaphores.
VkPhysicalDeviceFeatures2 features2;
memset(&features2, 0, sizeof(features2));
features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
create_info.pNext = &features2;
VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features;
memset(&semaphore_features, 0, sizeof(semaphore_features));
semaphore_features.sType =
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES;
semaphore_features.pNext = features2.pNext;
features2.pNext = &semaphore_features;
semaphore_features.timelineSemaphore = VK_TRUE;
err = vkCreateDevice(*physical_device, &create_info, allocator, device);
check_vk_result(err);
vkGetDeviceQueue(*device, *queue_family_index, 0, queue);
}
// Create Descriptor Pool
{
VkDescriptorPoolSize pool_sizes[] = {
{VK_DESCRIPTOR_TYPE_SAMPLER, 1000},
{VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 1000},
{VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 1000},
{VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1000},
{VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER, 1000},
{VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER, 1000},
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1000},
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1000},
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC, 1000},
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC, 1000},
{VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT, 1000}};
VkDescriptorPoolCreateInfo pool_info = {};
pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
pool_info.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
pool_info.maxSets = 1000 * IREE_ARRAYSIZE(pool_sizes);
pool_info.poolSizeCount = (uint32_t)IREE_ARRAYSIZE(pool_sizes);
pool_info.pPoolSizes = pool_sizes;
err =
vkCreateDescriptorPool(*device, &pool_info, allocator, descriptor_pool);
check_vk_result(err);
}
}
void SetupVulkanWindow(ImGui_ImplVulkanH_Window* wd,
const VkAllocationCallbacks* allocator,
VkInstance instance, uint32_t queue_family_index,
VkPhysicalDevice physical_device, VkDevice device,
VkSurfaceKHR surface, int width, int height,
uint32_t min_image_count) {
wd->Surface = surface;
// Check for WSI support
VkBool32 res;
vkGetPhysicalDeviceSurfaceSupportKHR(physical_device, queue_family_index,
wd->Surface, &res);
if (res != VK_TRUE) {
fprintf(stderr, "Error no WSI support on physical device 0\n");
exit(-1);
}
// Select Surface Format
const VkFormat requestSurfaceImageFormat[] = {
VK_FORMAT_B8G8R8A8_UNORM, VK_FORMAT_R8G8B8A8_UNORM,
VK_FORMAT_B8G8R8_UNORM, VK_FORMAT_R8G8B8_UNORM};
const VkColorSpaceKHR requestSurfaceColorSpace =
VK_COLORSPACE_SRGB_NONLINEAR_KHR;
wd->SurfaceFormat = ImGui_ImplVulkanH_SelectSurfaceFormat(
physical_device, wd->Surface, requestSurfaceImageFormat,
(size_t)IREE_ARRAYSIZE(requestSurfaceImageFormat),
requestSurfaceColorSpace);
// Select Present Mode
#ifdef IMGUI_UNLIMITED_FRAME_RATE
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_MAILBOX_KHR,
VK_PRESENT_MODE_IMMEDIATE_KHR,
VK_PRESENT_MODE_FIFO_KHR};
#else
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_FIFO_KHR};
#endif
wd->PresentMode = ImGui_ImplVulkanH_SelectPresentMode(
physical_device, wd->Surface, &present_modes[0],
IREE_ARRAYSIZE(present_modes));
// Create SwapChain, RenderPass, Framebuffer, etc.
IM_ASSERT(min_image_count >= 2);
ImGui_ImplVulkanH_CreateOrResizeWindow(instance, physical_device, device, wd,
queue_family_index, allocator, width,
height, min_image_count);
// Set clear color.
ImVec4 clear_color = ImVec4(0.45f, 0.55f, 0.60f, 1.00f);
memcpy(&wd->ClearValue.color.float32[0], &clear_color, 4 * sizeof(float));
}
void RenderFrame(ImGui_ImplVulkanH_Window* wd, VkDevice device, VkQueue queue) {
VkResult err;
VkSemaphore image_acquired_semaphore =
wd->FrameSemaphores[wd->SemaphoreIndex].ImageAcquiredSemaphore;
VkSemaphore render_complete_semaphore =
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
err = vkAcquireNextImageKHR(device, wd->Swapchain, UINT64_MAX,
image_acquired_semaphore, VK_NULL_HANDLE,
&wd->FrameIndex);
check_vk_result(err);
ImGui_ImplVulkanH_Frame* fd = &wd->Frames[wd->FrameIndex];
{
err = vkWaitForFences(
device, 1, &fd->Fence, VK_TRUE,
UINT64_MAX); // wait indefinitely instead of periodically checking
check_vk_result(err);
err = vkResetFences(device, 1, &fd->Fence);
check_vk_result(err);
}
{
err = vkResetCommandPool(device, fd->CommandPool, 0);
check_vk_result(err);
VkCommandBufferBeginInfo info = {};
info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
err = vkBeginCommandBuffer(fd->CommandBuffer, &info);
check_vk_result(err);
}
{
VkRenderPassBeginInfo info = {};
info.sType = VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO;
info.renderPass = wd->RenderPass;
info.framebuffer = fd->Framebuffer;
info.renderArea.extent.width = wd->Width;
info.renderArea.extent.height = wd->Height;
info.clearValueCount = 1;
info.pClearValues = &wd->ClearValue;
vkCmdBeginRenderPass(fd->CommandBuffer, &info, VK_SUBPASS_CONTENTS_INLINE);
}
// Record Imgui Draw Data and draw funcs into command buffer
ImGui_ImplVulkan_RenderDrawData(ImGui::GetDrawData(), fd->CommandBuffer);
// Submit command buffer
vkCmdEndRenderPass(fd->CommandBuffer);
{
VkPipelineStageFlags wait_stage =
VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT;
VkSubmitInfo info = {};
info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
info.waitSemaphoreCount = 1;
info.pWaitSemaphores = &image_acquired_semaphore;
info.pWaitDstStageMask = &wait_stage;
info.commandBufferCount = 1;
info.pCommandBuffers = &fd->CommandBuffer;
info.signalSemaphoreCount = 1;
info.pSignalSemaphores = &render_complete_semaphore;
err = vkEndCommandBuffer(fd->CommandBuffer);
check_vk_result(err);
err = vkQueueSubmit(queue, 1, &info, fd->Fence);
check_vk_result(err);
}
}
void PresentFrame(ImGui_ImplVulkanH_Window* wd, VkQueue queue) {
VkSemaphore render_complete_semaphore =
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
VkPresentInfoKHR info = {};
info.sType = VK_STRUCTURE_TYPE_PRESENT_INFO_KHR;
info.waitSemaphoreCount = 1;
info.pWaitSemaphores = &render_complete_semaphore;
info.swapchainCount = 1;
info.pSwapchains = &wd->Swapchain;
info.pImageIndices = &wd->FrameIndex;
VkResult err = vkQueuePresentKHR(queue, &info);
check_vk_result(err);
wd->SemaphoreIndex =
(wd->SemaphoreIndex + 1) %
wd->ImageCount; // Now we can use the next set of semaphores
}
static void CleanupVulkan() {
vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator);
vkDestroyDevice(g_Device, g_Allocator);
vkDestroyInstance(g_Instance, g_Allocator);
}
static void CleanupVulkanWindow() {
ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData,
g_Allocator);
}
namespace iree {
extern "C" int iree_main(int argc, char** argv) {
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
if (argc > 1) {
// Avoid iree-run-module spinning endlessly on stdin if the user uses single
// dashes for flags.
printf(
"[ERROR] unexpected positional argument (expected none)."
" Did you use pass a flag with a single dash ('-')?"
" Use '--' instead.\n");
return 1;
}
// --------------------------------------------------------------------------
// Create a window.
if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
fprintf(stderr, "Failed to initialize SDL\n");
abort();
return 1;
}
// Setup window
// clang-format off
SDL_WindowFlags window_flags = (SDL_WindowFlags)(
SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
// clang-format on
SDL_Window* window = SDL_CreateWindow(
"IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
if (window == nullptr)
{
const char* sdl_err = SDL_GetError();
fprintf(stderr, "Error, SDL_CreateWindow returned: %s\n", sdl_err);
abort();
return 1;
}
// Setup Vulkan
iree_hal_vulkan_features_t iree_vulkan_features =
static_cast<iree_hal_vulkan_features_t>(
IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS |
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
std::vector<const char*> extensions =
GetInstanceExtensions(window, iree_vulkan_features);
SetupVulkan(iree_vulkan_features, layers.data(),
static_cast<uint32_t>(layers.size()), extensions.data(),
static_cast<uint32_t>(extensions.size()), g_Allocator,
&g_Instance, &g_QueueFamily, &g_PhysicalDevice, &g_Queue,
&g_Device, &g_DescriptorPool);
// Create Window Surface
VkSurfaceKHR surface;
VkResult err;
if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
fprintf(stderr, "Failed to create Vulkan surface.\n");
abort();
return 1;
}
// Create Framebuffers
int w, h;
SDL_GetWindowSize(window, &w, &h);
ImGui_ImplVulkanH_Window* wd = &g_MainWindowData;
SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily,
g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount);
// Setup Dear ImGui context
IMGUI_CHECKVERSION();
ImGui::CreateContext();
ImGuiIO& io = ImGui::GetIO();
(void)io;
ImGui::StyleColorsDark();
// Setup Platform/Renderer bindings
ImGui_ImplSDL2_InitForVulkan(window);
ImGui_ImplVulkan_InitInfo init_info = {};
init_info.Instance = g_Instance;
init_info.PhysicalDevice = g_PhysicalDevice;
init_info.Device = g_Device;
init_info.QueueFamily = g_QueueFamily;
init_info.Queue = g_Queue;
init_info.PipelineCache = g_PipelineCache;
init_info.DescriptorPool = g_DescriptorPool;
init_info.Allocator = g_Allocator;
init_info.MinImageCount = g_MinImageCount;
init_info.ImageCount = wd->ImageCount;
init_info.CheckVkResultFn = check_vk_result;
ImGui_ImplVulkan_Init(&init_info, wd->RenderPass);
// Upload Fonts
{
// Use any command queue
VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool;
VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer;
err = vkResetCommandPool(g_Device, command_pool, 0);
check_vk_result(err);
VkCommandBufferBeginInfo begin_info = {};
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
err = vkBeginCommandBuffer(command_buffer, &begin_info);
check_vk_result(err);
ImGui_ImplVulkan_CreateFontsTexture(command_buffer);
VkSubmitInfo end_info = {};
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
end_info.commandBufferCount = 1;
end_info.pCommandBuffers = &command_buffer;
err = vkEndCommandBuffer(command_buffer);
check_vk_result(err);
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
check_vk_result(err);
err = vkDeviceWaitIdle(g_Device);
check_vk_result(err);
ImGui_ImplVulkan_DestroyFontUploadObjects();
}
// Demo state.
bool show_iree_window = true;
// --------------------------------------------------------------------------
// Setup IREE.
// Check API version.
iree_api_version_t actual_version;
iree_status_t status =
iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version);
if (iree_status_is_ok(status)) {
fprintf(stdout, "IREE runtime API version: %d\n", actual_version);
} else {
fprintf(stderr, "Unsupported runtime API version: %d\n", actual_version);
abort();
}
// Create a runtime Instance.
iree_vm_instance_t* iree_instance = nullptr;
IREE_CHECK_OK(
iree_vm_instance_create(iree_allocator_system(), &iree_instance));
// Register HAL drivers and VM module types.
IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(
iree_hal_driver_registry_default()));
IREE_CHECK_OK(iree_hal_module_register_all_types(iree_instance));
// Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice.
fprintf(stdout, "Creating Vulkan driver/device\n");
// Load symbols from our static `vkGetInstanceProcAddr` for IREE to use.
iree_hal_vulkan_syms_t* iree_vk_syms = nullptr;
IREE_CHECK_OK(iree_hal_vulkan_syms_create(
reinterpret_cast<void*>(&vkGetInstanceProcAddr), iree_allocator_system(),
&iree_vk_syms));
// Create the driver sharing our VkInstance.
iree_hal_driver_t* iree_vk_driver = nullptr;
iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan");
iree_hal_vulkan_driver_options_t driver_options;
driver_options.api_version = VK_API_VERSION_1_0;
driver_options.requested_features = static_cast<iree_hal_vulkan_features_t>(
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance(
driver_identifier, &driver_options, iree_vk_syms, g_Instance,
iree_allocator_system(), &iree_vk_driver));
// Create a device sharing our VkDevice and queue.
// We could also create a separate (possibly low priority) compute queue for
// IREE, and/or provide a dedicated transfer queue.
iree_string_view_t device_identifier = iree_make_cstring_view("vulkan");
iree_hal_vulkan_queue_set_t compute_queue_set;
compute_queue_set.queue_family_index = g_QueueFamily;
compute_queue_set.queue_indices = 1 << 0;
iree_hal_vulkan_queue_set_t transfer_queue_set;
transfer_queue_set.queue_indices = 0;
iree_hal_device_t* iree_vk_device = nullptr;
IREE_CHECK_OK(iree_hal_vulkan_wrap_device(
device_identifier, &driver_options.device_options, iree_vk_syms,
g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set,
&transfer_queue_set, iree_allocator_system(), &iree_vk_device));
// Create a HAL module using the HAL device.
iree_vm_module_t* hal_module = nullptr;
IREE_CHECK_OK(iree_hal_module_create(iree_instance, iree_vk_device,
IREE_HAL_MODULE_FLAG_NONE,
iree_allocator_system(), &hal_module));
// Load bytecode module
//iree_file_toc_t module_file_toc;
//const char network_model[] = "resnet50_tf.vmfb";
//fprintf(stdout, "Loading: %s\n", network_model);
//if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
//{
// abort();
// return 1;
//}
//fprintf(stdout, "module size: %zu\n", module_file_toc.size);
iree_vm_module_t* bytecode_module = nullptr;
iree_status_t module_status = iree_tooling_load_module_from_flags(
iree_instance, iree_allocator_system(), &bytecode_module);
if (!iree_status_is_ok(module_status))
return -1;
//IREE_CHECK_OK(iree_vm_bytecode_module_create(
// iree_instance,
// iree_const_byte_span_t{
// reinterpret_cast<const uint8_t*>(module_file_toc.data),
// module_file_toc.size},
// iree_allocator_null(), iree_allocator_system(), &bytecode_module));
//// Query for details about what is in the loaded module.
//iree_vm_module_signature_t bytecode_module_signature =
// iree_vm_module_signature(bytecode_module);
//fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
// bytecode_module_signature.export_function_count);
//for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
// iree_vm_function_t function;
// IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
// bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
// auto function_name = iree_vm_function_name(&function);
// auto function_signature = iree_vm_function_signature(&function);
// fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
// (int)function_name.size, function_name.data,
// (int)function_signature.calling_convention.size,
// function_signature.calling_convention.data);
//}
// Allocate a context that will hold the module state across invocations.
iree_vm_context_t* iree_context = nullptr;
std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
IREE_CHECK_OK(iree_vm_context_create_with_modules(
iree_instance, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(),
iree_allocator_system(), &iree_context));
fprintf(stdout, "Context with modules is ready for use\n");
// Lookup the entry point function.
iree_vm_function_t main_function;
const char kMainFunctionName[] = "module.forward";
IREE_CHECK_OK(iree_vm_context_resolve_function(
iree_context,
iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1},
&main_function));
iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
fprintf(stdout, "Resolved main function named '%.*s'\n",
(int)main_function_name.size, main_function_name.data);
// --------------------------------------------------------------------------
// Write inputs into mappable buffers.
iree_hal_allocator_t* allocator =
iree_hal_device_allocator(iree_vk_device);
//iree_hal_memory_type_t input_memory_type =
// static_cast<iree_hal_memory_type_t>(
// IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
//iree_hal_buffer_usage_t input_buffer_usage =
// static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
//iree_hal_buffer_params_t buffer_params;
//buffer_params.type = input_memory_type;
//buffer_params.usage = input_buffer_usage;
//buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
// Wrap input buffers in buffer views.
vm::ref<iree_vm_list_t> inputs;
iree_status_t input_status = ParseToVariantList(
allocator,
iree::span<const std::string>{FLAG_function_inputs.data(),
FLAG_function_inputs.size()},
iree_allocator_system(), &inputs);
if (!iree_status_is_ok(input_status))
return -1;
//vm::ref<iree_vm_list_t> inputs;
//IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
//iree_hal_buffer_view_t* input0_buffer_view = nullptr;
//constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
//IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
// allocator,
// /*shape_rank=*/4, /*shape=*/input_buffer_shape,
// IREE_HAL_ELEMENT_TYPE_FLOAT_32,
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
// iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
// &input0_buffer_view));
//auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
//IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
// Prepare outputs list to accept results from the invocation.
vm::ref<iree_vm_list_t> outputs;
constexpr iree_hal_dim_t kOutputCount = 1000;
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
// --------------------------------------------------------------------------
// Main loop.
bool done = false;
while (!done) {
SDL_Event event;
while (SDL_PollEvent(&event)) {
if (event.type == SDL_QUIT) {
done = true;
}
ImGui_ImplSDL2_ProcessEvent(&event);
if (event.type == SDL_QUIT) done = true;
if (event.type == SDL_WINDOWEVENT &&
event.window.event == SDL_WINDOWEVENT_RESIZED &&
event.window.windowID == SDL_GetWindowID(window)) {
g_SwapChainResizeWidth = (int)event.window.data1;
g_SwapChainResizeHeight = (int)event.window.data2;
g_SwapChainRebuild = true;
}
}
if (g_SwapChainRebuild) {
g_SwapChainRebuild = false;
ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount);
ImGui_ImplVulkanH_CreateOrResizeWindow(
g_Instance, g_PhysicalDevice, g_Device, &g_MainWindowData,
g_QueueFamily, g_Allocator, g_SwapChainResizeWidth,
g_SwapChainResizeHeight, g_MinImageCount);
g_MainWindowData.FrameIndex = 0;
}
// Start the Dear ImGui frame
ImGui_ImplVulkan_NewFrame();
ImGui_ImplSDL2_NewFrame(window);
ImGui::NewFrame();
// Custom window.
{
ImGui::Begin("IREE Vulkan Integration Demo", &show_iree_window);
ImGui::Separator();
// ImGui Inputs for two input tensors.
// Run computation whenever any of the values changes.
static bool dirty = true;
if (dirty) {
// Synchronously invoke the function.
IREE_CHECK_OK(iree_vm_invoke(iree_context, main_function,
IREE_VM_INVOCATION_FLAG_NONE,
/*policy=*/nullptr, inputs.get(),
outputs.get(), iree_allocator_system()));
// we want to run continuously so we can use tools like RenderDoc, RGP, etc...
dirty = true;
}
// Framerate counter.
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
ImGui::End();
}
// Rendering
ImGui::Render();
RenderFrame(wd, g_Device, g_Queue);
PresentFrame(wd, g_Queue);
}
// --------------------------------------------------------------------------
// --------------------------------------------------------------------------
// Cleanup
iree_vm_module_release(hal_module);
iree_vm_module_release(bytecode_module);
iree_vm_context_release(iree_context);
iree_hal_device_release(iree_vk_device);
iree_hal_allocator_release(allocator);
iree_hal_driver_release(iree_vk_driver);
iree_hal_vulkan_syms_release(iree_vk_syms);
iree_vm_instance_release(iree_instance);
err = vkDeviceWaitIdle(g_Device);
check_vk_result(err);
ImGui_ImplVulkan_Shutdown();
ImGui_ImplSDL2_Shutdown();
ImGui::DestroyContext();
CleanupVulkanWindow();
CleanupVulkan();
SDL_DestroyWindow(window);
SDL_Quit();
// --------------------------------------------------------------------------
return 0;
}
} // namespace iree

File diff suppressed because it is too large Load Diff

View File

@@ -18,7 +18,7 @@ import re
import tempfile
from pathlib import Path
import iree.runtime as ireert
#import iree.runtime as ireert
import iree.compiler as ireec
from shark.parser import shark_args
@@ -684,21 +684,21 @@ def get_results(
dl.log("Execution complete")
@functools.cache
def get_iree_runtime_config(device):
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
if "metal" in device and shark_args.device_allocator == "caching":
print(
"[WARNING] metal devices can not have a `caching` allocator."
"\nUsing default allocator `None`"
)
haldevice = haldriver.create_device_by_uri(
device,
# metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream.
allocators=shark_args.device_allocator
if "metal" not in device
else None,
)
config = ireert.Config(device=haldevice)
return config
# @functools.cache
# def get_iree_runtime_config(device):
# device = iree_device_map(device)
# haldriver = ireert.get_driver(device)
# if "metal" in device and shark_args.device_allocator == "caching":
# print(
# "[WARNING] metal devices can not have a `caching` allocator."
# "\nUsing default allocator `None`"
# )
# haldevice = haldriver.create_device_by_uri(
# device,
# # metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream.
# allocators=shark_args.device_allocator
# if "metal" not in device
# else None,
# )
# config = ireert.Config(device=haldevice)
# return config