SHARK-Studio/shark/iree_utils/compile_utils.py
Ean Garvey 0eff62a468 (Studio 2.0) add Stable Diffusion features (#2037)
* (WIP): Studio2 app infra and SD API

UI/app structure and utility implementation.

- Initializers for webui/API launch
- Schedulers file for SD scheduling utilities
- Additions to API-level utilities
- Added embeddings module for LoRA, Lycoris, yada yada
- Added image_processing module for resamplers, resize tools,
  transforms, and any image annotation (PNG metadata)
- shared_cmd_opts module -- sorry, this is stable_args.py. It lives on.
  We still want to have some global control over the app exclusively
  from the command-line. At least we will be free from shark_args.
- Moving around some utility pieces.
- Try to make api+webui concurrency possible in index.py
- SD UI -- this is just img2imgUI but hopefully a little better.
- UI utilities for your nod logos and your gradio temps.

Enable UI / bugfixes / tweaks

* Studio2/SD: Use more correct LoRA alpha calculation (#2034)

* Updates ProcessLoRA to use both embedded LoRA alpha, and lora_strength
optional parameter (default 1.0) when applying LoRA weights.
* Updates ProcessLoRA to cover more dim cases.
* This brings ProcessLoRA into line with PR #2015 against Studio1
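
As a rough illustration of how an embedded alpha and a `lora_strength` parameter can combine: a common LoRA convention scales the low-rank weight delta by `lora_strength * alpha / rank`. This sketch assumes that convention and is not taken from ProcessLoRA itself. The function name `lora_scale` and the fallback used when no alpha is embedded are hypothetical.

```python
def lora_scale(alpha, rank, lora_strength=1.0):
    """Return the multiplier applied to a LoRA weight delta (up @ down)."""
    if rank <= 0:
        raise ValueError("rank must be positive")
    # Assumption: when the file embeds no alpha, treat alpha as equal to
    # rank, so that only lora_strength affects the result.
    if alpha is None:
        alpha = rank
    return lora_strength * (alpha / rank)


print(lora_scale(alpha=16, rank=32))
print(lora_scale(alpha=None, rank=8, lora_strength=0.5))
```

With this convention the default `lora_strength` of 1.0 leaves the embedded alpha scaling unchanged.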

* Studio2: Remove duplications from api/utils.py (#2035)

* Remove duplicate os import
* Remove duplicate parse_seed_input function

Migrating to JSON requests in SD UI

More UI and app flow improvements, logging, shared device cache

Model loading

Complete SD pipeline.

Tweaks to VAE, pipeline states

Pipeline tweaks, add cmd_opts parsing to sd api

* Add test for SD

* Small cleanup

* Shark2/SD/UI: Respect ckpt_dir, share and server_port args (#2070)

* Takes whether to generate a gradio live link from the existing --share command
line parameter, rather than hardcoding as True.
* Takes server port from existing --server_port command line parameter, rather than
hardcoding as 11911.
* Default --ckpt_dir parameter to '../models'
* Use --ckpt_dir rather than hardcoding ../models as the base directory for
checkpoints, vae, and lora, etc
* Add a 'checkpoints' directory below --ckpt_dir to match ComfyUI folder structure.
Read custom_weights choices from there, and/or subfolders below there matching
the selected base model.
* Fix --ckpt_dir possibly not working correctly when an absolute rather than relative path
is specified.
* Relabel "Custom Weights" to "Custom Weights Checkpoint" in the UI

* Add StreamingLLM support to studio2 chat (#2060)

* Streaming LLM

* Update precision and add gpu support

* (studio2) Separate weights generation for quantization support

* Adapt prompt changes to studio flow

* Remove outdated flag from llm compile flags.

* (studio2) use turbine vmfbRunner

* tweaks to prompts

* Update CPU path and llm api test.

* Change device in test to cpu.

* Fixes to runner, device names, vmfb mgmt

* Use small test without external weights.

* HF-Reference LLM mode + Update test result to match latest Turbine. (#2080)

* HF-Reference LLM mode.

* Fixup test to match current output from Turbine.

* lint

* Fix test error message + Only initialize HF torch model when used.

* Remove redundant format_out change.

* Add rest API endpoint from LanguageModel API

* Formatting and init files.

* Remove unused import.

* Small fixes

* Studio2/SD/UI: Improve various parts of the UI for Stable Diffusion (#2074)

* Studio2/SD/UI: Improve various parts of the UI of Shark 2

* Update Gradio pin to 4.15.0.
* Port workarounds for Gradio >4.8.0 main container sizing from Shark 1.0.
* Move nod Logo out of the SD tab and onto the top right of the main tab bar.
* Set nod logo icon as the favicon (as current Shark 1.0).
* Create a tabbed right hand panel within the SD UI sized to the viewport height.
* Make Input Image tab 1 in the right hand panel.
* Make output images, generation log, and generation buttons, tab 2 in the
right hand panel
* Make config JSON display, with config load, save and clear, tab 3 in the
right hand panel
* Make gallery area of the Output tab take up all vertical space the other controls
on the tab do not.
* Tidy up the controls on the Config tab somewhat.

* Studio2/SD/UI: Reorganise inputs on Left Panel of SD tab

* Rename previously added Right Panel Output tab to 'Generate'.
* Move Batch Count, Batch Size, and Repeatable Seeds, off of Left Panel and onto 'Generate' Tab.
* On 'Generate' tab, rename 'Generate Image(s)' button to 'Start', and 'Stop Batch' button to 'Stop'. They are now below the Batch inputs on a Generate tab so don't need the specificity.
* Move Device, Low VRAM, and Precision inputs into their own 'Device Settings' Accordion control. (starts closed)
* Rename 'Custom Weights Checkpoint' to 'Checkpoint Weights'
* Move Checkpoint Weights, VAE Model, Standalone Lora Weights, and Embeddings Options controls, into their own 'Model Weights' Accordion control.  (starts closed)
* Move Denoising Strength, and Resample Type controls into their own 'Input Image Processing' Accordion. (starts closed)
* Move any remaining controls in the 'Advanced Options' Accordion directly onto the left panel, and remove the Accordion.
* Enable the copy button for all text boxes on the SD tab.
* Add emoji/unicode glyphs to all top level controls and Accordions on the SD Left Panel.
* Start with the 'Generate' as the initially selected tab in the SD Right Panel, working around Gradio issue #7805
* Tweaks to SD Right Tab Panel vertical height.

* Studio2/SD/UI: Sizing tweaks for Right Panel, and >1920 width

* Set height of right panel using vmin rather than vh, with explicit affordances
for fixed areas above and below.
* Port >1920 width Gradio >4.8 CSS workaround from Shark 1.0.

* Studio2/SD: Fix sd pipeline up to "Windows not supported" (#2082)

* Studio2/SD: Fix sd pipeline up to "Windows not supported"

A number of fixes to the SD pipeline as run from the UI, up until the point that dynamo
complains "Windows not yet supported for torch.compile".

* Remove separate install of iree-runtime and iree-compile in setup_venv.ps1, and rely on the
versions installed via the Turbine requirements.txt. Fixes #2063 for me.
* Replace any "None" strings with python None when pulling the config in the UI.
* Add 'hf_auth_token' param to api StableDiffusion class, defaulting to None, and then pass
that in to the various Models where it is required and wasn't already being done before.
* Fix clip custom_weight_params being passed to export_clip_model as "external_weight_file"
rather than "external_weights"
* Don't pass non-existing "custom_vae" parameter to the Turbine Vae Model, instead
pass custom_vae as the "hf_model_id" if it is set. (this may be wrong in the custom vae
case, but stops the code *always* breaking).

* Studio2/SD/UI: Improve UI config None handling

* When populating the UI from a JSON Config set controls to "None" for null/None
values.
* When generating a JSON Config from the UI set props to null/None for controls
set to "None".
* Use null rather than the string 'None' in the default config
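
The None handling in the list above can be sketched as a pair of dictionary conversions. This is a minimal illustration, not the actual UI code. The function names `config_to_ui` and `ui_to_config` and the `custom_vae` key are hypothetical.

```python
import json


def config_to_ui(config):
    """Show null/None config values as the "None" choice in UI controls."""
    return {key: ("None" if value is None else value) for key, value in config.items()}


def ui_to_config(ui_values):
    """Turn controls set to "None" back into real None (JSON null)."""
    return {key: (None if value == "None" else value) for key, value in ui_values.items()}


ui_state = config_to_ui({"custom_vae": None, "steps": 50})
print(ui_state)
print(json.dumps(ui_to_config(ui_state)))
```

Keeping the conversion at the UI boundary means the saved JSON carries only real nulls, never the string "None".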

---------

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>

* Studio2/SD/UI: Further sd ui pipeline fixes (#2091)

On Windows, this gets us all the way to a failure in the iree compile step with SD 2.1 base.

- Fix merge errors with sd right pane config UI tab.
- Remove non-requirement.txt install/build of torch/mlir/iree/SRT in setup_venv.ps1, fixing "torch.compile not supported on Windows" error.
- Fix gradio deprecation warning for `root=` FileExplorer kwarg.
- Comment out `precision` and `max_length` kwargs being passed to unet, as not yet supported on main Turbine branch. Avoids keyword argument error.

* Tweak compile-time flags for SD submodels.

* Small fixes to sd, pin mpmath

* Add pyinstaller spec and imports script.

* Fix the .exe (#2101)

* Fix _IREE_TARGET_MAP (#2103) (#2108)

- Change target passed to iree for vulkan from 'vulkan'
to 'vulkan-spirv', as 'vulkan' is not a valid value for
--iree-hal-target-backends with the current iree compiler.
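
A minimal sketch of the kind of device-to-backend lookup this fix concerns. The table contents and the names `TARGET_BACKENDS`, `target_backend_for` and `target_backend_flag` are illustrative assumptions. The real map sits behind `iree_target_map` and may differ.

```python
# Hypothetical table; only the vulkan entry is taken from the note above.
TARGET_BACKENDS = {
    "cpu": "llvm-cpu",
    "vulkan": "vulkan-spirv",
    "cuda": "cuda",
    "rocm": "rocm",
}


def target_backend_for(device):
    """Look up the compiler backend for a device string such as 'vulkan://0'."""
    driver = device.split("://")[0]
    # Unknown drivers are passed through unchanged in this sketch.
    return TARGET_BACKENDS.get(driver, driver)


def target_backend_flag(device):
    """Build the iree-compile flag that selects the backend."""
    return f"--iree-hal-target-backends={target_backend_for(device)}"


print(target_backend_flag("vulkan://0"))
```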

Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>

* Cleanup sd model map.

* Update dependencies.

* Studio2/SD/UI: Update gradio to 4.19.2 (sd-studio2) (#2097)

- Move pin for gradio from 4.15 -> 4.19.2 on the sd-studio2 branch

* fix formatting and disable explicit vulkan env settings.

---------

Co-authored-by: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com>
Co-authored-by: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Co-authored-by: gpetters-amd <159576198+gpetters-amd@users.noreply.github.com>
Co-authored-by: gpetters94 <gpetters@protonmail.com>
2024-03-29 18:13:21 -04:00

702 lines
24 KiB
Python

# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import numpy as np
import os
import re
import tempfile
from pathlib import Path
import iree.runtime as ireert
import iree.compiler as ireec
from shark.parser import shark_args
from .trace import DetailLogger
from ._common import iree_device_map, iree_target_map
from .cpu_utils import get_iree_cpu_rt_args
from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
    print("Configuring for device:" + device)
    device, device_num = clean_device_info(device)
    if "cpu" in device:
        from shark.iree_utils.cpu_utils import get_iree_cpu_args
        u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
        stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
        return (
            get_iree_cpu_args()
            + u_kernel_flag
            + stack_size_flag
        )
    if device == "cuda":
        from shark.iree_utils.gpu_utils import get_iree_gpu_args
        return get_iree_gpu_args()
    if device == "vulkan":
        from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
        return get_iree_vulkan_args(
            device_num=device_num, extra_args=extra_args
        )
    if device == "metal":
        from shark.iree_utils.metal_utils import get_iree_metal_args
        return get_iree_metal_args(extra_args=extra_args)
    if device == "rocm":
        from shark.iree_utils.gpu_utils import get_iree_rocm_args
        return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
    return []
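def get_iree_target_triple(device):
    args = get_iree_device_args(device)
    for flag in args:
        # Match flags of the form `...-target-triple=<value>` and return the
        # value. Splitting on "-" never yields a bare "triple" token, because
        # the value is still attached as "triple=<value>".
        if "triple" in flag and "=" in flag:
            return flag.split("=", 1)[1]
    return ""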
def clean_device_info(raw_device):
    # return appropriate device and device_id for consumption by Studio pipeline
    # Multiple devices only supported for vulkan and rocm (as of now).
    # default device must be selected for all others
    device_id = None
    device = (
        raw_device
        if "=>" not in raw_device
        else raw_device.split("=>")[1].strip()
    )
    if "://" in device:
        device, device_id = device.split("://")
        # Short ids are treated as numeric indices; longer ids stay strings.
        if len(device_id) <= 2:
            device_id = int(device_id)
    if device not in ["rocm", "vulkan"]:
        device_id = None
    if device in ["rocm", "vulkan"] and device_id is None:
        device_id = 0
    return device, device_id
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
    if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
        return ["--iree-llvmcpu-target-cpu-features=host"]
    elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
        return [
            "--iree-llvmcpu-target-cpu-features=host",
            "--iree-input-demote-i64-to-i32",
        ]
    else:
        # Frontend not found.
        return []
# Common args to be used given any frontend or device.
def get_iree_common_args(debug=False):
    common_args = [
        "--iree-util-zero-fill-elided-attrs",
        "--mlir-elide-elementsattrs-if-larger=10",
    ]
    if debug == True:
        common_args.extend(
            [
                "--iree-opt-strip-assertions=false",
                "--verify=true",
            ]
        )
    else:
        common_args.extend(
            [
                "--iree-opt-strip-assertions=true",
                "--verify=false",
            ]
        )
    return common_args
# Args that are suitable only for certain models or groups of models.
# shark_args are passed down from pytests to control which models compile with these flags,
# but they can also be set in shark/parser.py
def get_model_specific_args():
    ms_args = []
    if shark_args.enable_conv_transform == True:
        ms_args += [
            "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))"
        ]
    if shark_args.enable_img2col_transform == True:
        ms_args += [
            "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col))"
        ]
    if shark_args.use_winograd == True:
        ms_args += [
            "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-linalg-ext-convert-conv2d-to-winograd))"
        ]
    return ms_args
def create_dispatch_dirs(bench_dir, device):
    protected_files = ["ordered-dispatches.txt"]
    bench_dir_path = bench_dir.split("/")
    bench_dir_path[-1] = "temp_" + bench_dir_path[-1]
    tmp_bench_dir = "/".join(bench_dir_path)
    for f_ in os.listdir(bench_dir):
        if os.path.isfile(f"{bench_dir}/{f_}") and f_ not in protected_files:
            # Strip the extension to get the per-dispatch directory name.
            # Raw strings avoid invalid-escape warnings for the regex.
            dir_name = re.sub(r"\.\S*$", "", f_)
            if os.path.exists(f"{bench_dir}/{dir_name}"):
                os.system(f"rm -rf {bench_dir}/{dir_name}")
            os.system(f"mkdir {bench_dir}/{dir_name}")
            os.system(f"mv {bench_dir}/{f_} {bench_dir}/{dir_name}/{f_}")
    for f_ in os.listdir(tmp_bench_dir):
        if os.path.isfile(f"{tmp_bench_dir}/{f_}"):
            dir_name = ""
            for d_ in os.listdir(bench_dir):
                if re.search(rf"{d_}(?=\D)", f_):
                    dir_name = d_
            if dir_name != "":
                os.system(
                    f"mv {tmp_bench_dir}/{f_} {bench_dir}/{dir_name}/{dir_name}_benchmark.mlir"
                )
def dump_isas(bench_dir):
    for d_ in os.listdir(bench_dir):
        if os.path.isdir(f"{bench_dir}/{d_}"):
            for f_ in os.listdir(f"{bench_dir}/{d_}"):
                if f_.endswith(".spv"):
                    os.system(
                        f"amdllpc -gfxip 11.0 {bench_dir}/{d_}/{f_} -v > \
                         {bench_dir}/{d_}/isa.txt"
                    )
def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
    benchmark_runtimes = {}
    dispatch_list = []
    all_dispatches = False
    if dispatch_benchmarks.lower().strip() == "all":
        all_dispatches = True
    else:
        try:
            dispatch_list = [
                int(dispatch_index)
                for dispatch_index in dispatch_benchmarks.split(" ")
            ]
        except ValueError:
            # Raised by int() when an entry is not a dispatch index.
            print("ERROR: Invalid dispatch benchmarks")
            return None
    for d_ in os.listdir(bench_dir):
        if os.path.isdir(f"{bench_dir}/{d_}"):
            in_dispatches = False
            for dispatch in dispatch_list:
                if str(dispatch) in d_:
                    in_dispatches = True
            if all_dispatches or in_dispatches:
                for f_ in os.listdir(f"{bench_dir}/{d_}"):
                    if "benchmark.mlir" in f_:
                        dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
                        module = dispatch_file.read()
                        dispatch_file.close()
                        flatbuffer_blob = ireec.compile_str(
                            module, target_backends=[iree_target_map(device)]
                        )
                        vmfb_file = open(
                            f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", "wb"
                        )
                        vmfb_file.write(flatbuffer_blob)
                        vmfb_file.close()
                        config = get_iree_runtime_config(device)
                        vm_module = ireert.VmModule.from_buffer(
                            config.vm_instance,
                            flatbuffer_blob,
                            warn_if_copy=False,
                        )
                        benchmark_cl = build_benchmark_args_non_tensor_input(
                            input_file=f"{bench_dir}/{d_}/{d_}_benchmark.vmfb",
                            device=device,
                            inputs=(0,),
                            mlir_dialect="linalg",
                            function_name="",
                        )
                        benchmark_bash = open(
                            f"{bench_dir}/{d_}/{d_}_benchmark.sh", "w+"
                        )
                        benchmark_bash.write("#!/bin/bash\n")
                        benchmark_bash.write(" ".join(benchmark_cl))
                        benchmark_bash.close()
                        iter_per_second, _, _ = run_benchmark_module(
                            benchmark_cl
                        )
                        benchmark_file = open(
                            f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
                        )
                        benchmark_file.write(f"DISPATCH: {d_}\n")
                        benchmark_file.write(str(iter_per_second) + "\n")
                        # Convert iterations/second to milliseconds/iteration.
                        benchmark_file.write(
                            "SHARK BENCHMARK RESULT: "
                            + str(1 / (iter_per_second * 0.001))
                            + "\n"
                        )
                        benchmark_file.close()
                        benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
                    elif ".mlir" in f_ and "benchmark" not in f_:
                        dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
                        module = dispatch_file.read()
                        dispatch_file.close()
                        module = re.sub(
                            "hal.executable private",
                            "hal.executable public",
                            module,
                        )
                        flatbuffer_blob = ireec.compile_str(
                            module,
                            target_backends=[iree_target_map(device)],
                            extra_args=["--compile-mode=hal-executable"],
                        )
                        spirv_file = open(
                            f"{bench_dir}/{d_}/{d_}_spirv.vmfb", "wb"
                        )
                        spirv_file.write(flatbuffer_blob)
                        spirv_file.close()
    # Slowest dispatches first.
    ordered_dispatches = [
        (k, v)
        for k, v in sorted(
            benchmark_runtimes.items(), key=lambda item: item[1]
        )
    ][::-1]
    f_ = open(f"{bench_dir}/ordered-dispatches.txt", "w+")
    for dispatch in ordered_dispatches:
        f_.write(f"{dispatch[0]}: {dispatch[1]}ms\n")
    f_.close()
def compile_module_to_flatbuffer(
    module,
    device,
    frontend,
    model_config_path,
    extra_args,
    model_name="None",
    debug=False,
    compile_str=False,
    write_to=None,
):
    # Setup Compile arguments wrt to frontends.
    input_type = "auto"
    args = get_iree_frontend_args(frontend)
    args += get_iree_device_args(device, extra_args)
    args += get_iree_common_args(debug=debug)
    args += get_model_specific_args()
    args += extra_args
    args += shark_args.additional_compile_args
    if frontend in ["tensorflow", "tf"]:
        input_type = "auto"
    elif frontend in ["stablehlo", "tosa"]:
        input_type = frontend
    elif frontend in ["tflite", "tflite-tosa"]:
        input_type = "tosa"
    elif frontend in ["tm_tensor"]:
        input_type = ireec.InputType.TM_TENSOR
    elif frontend in ["torch", "pytorch"]:
        input_type = "torch"
    if compile_str:
        flatbuffer_blob = ireec.compile_str(
            module,
            target_backends=[iree_target_map(device)],
            extra_args=args,
            input_type=input_type,
        )
    else:
        assert os.path.isfile(module)
        flatbuffer_blob = ireec.compile_file(
            str(module),
            input_type=input_type,
            target_backends=[iree_target_map(device)],
            extra_args=args,
        )
    if write_to is not None:
        with open(write_to, "wb") as f:
            f.write(flatbuffer_blob)
        return None
    return flatbuffer_blob
def get_iree_module(
    flatbuffer_blob,
    device,
    device_idx=None,
    rt_flags: list = [],
    external_weight_file=None,
):
    if external_weight_file is not None:
        index = ireert.ParameterIndex()
        index.load(external_weight_file)
    # Returns the compiled module and the configs.
    for flag in rt_flags:
        # Same call as in load_vmfb_using_mmap.
        ireert.flags.parse_flags(flag)
    if device_idx is not None:
        device = iree_device_map(device)
        print("registering device id: ", device_idx)
        haldriver = ireert.get_driver(device)
        hal_device_id = haldriver.query_available_devices()[device_idx][
            "device_id"
        ]
        haldevice = haldriver.create_device(
            hal_device_id,
            allocators=shark_args.device_allocator,
        )
        config = ireert.Config(device=haldevice)
        config.id = hal_device_id
    else:
        config = get_iree_runtime_config(device)
    vm_module = ireert.VmModule.from_buffer(
        config.vm_instance, flatbuffer_blob, warn_if_copy=False
    )
    modules = []
    if external_weight_file is not None:
        # Wrap the provider in a parameters module, as load_vmfb_using_mmap
        # does, so the context receives a VM module.
        modules.append(
            ireert.create_io_parameters_module(
                config.vm_instance, index.create_provider(scope="model")
            )
        )
    ctx = ireert.SystemContext(vm_modules=modules, config=config)
    ctx.add_vm_module(vm_module)
    ModuleCompiled = getattr(ctx.modules, vm_module.name)
    return ModuleCompiled, config
def load_vmfb_using_mmap(
    flatbuffer_blob_or_path,
    device: str,
    device_idx: int = None,
    rt_flags: list = [],
    external_weight_file: str = None,
):
    print(f"Loading module {flatbuffer_blob_or_path}...")
    # Copy so that appending below never mutates the shared default list
    # or the caller's list.
    rt_flags = list(rt_flags)
    if "task" in device:
        print(
            f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
        )
        for flag in get_iree_cpu_rt_args():
            rt_flags.append(flag)
    for flag in rt_flags:
        print(flag)
        ireert.flags.parse_flags(flag)
    if "rocm" in device:
        device = "rocm"
    with DetailLogger(timeout=2.5) as dl:
        # First get configs.
        if device_idx is not None:
            dl.log(f"Mapping device id: {device_idx}")
            device = iree_device_map(device)
            haldriver = ireert.get_driver(device)
            dl.log(f"ireert.get_driver()")
            hal_device_id = haldriver.query_available_devices()[device_idx][
                "device_id"
            ]
            haldevice = haldriver.create_device(
                hal_device_id,
                allocators=shark_args.device_allocator,
            )
            dl.log(f"ireert.create_device()")
            config = ireert.Config(device=haldevice)
            config.id = hal_device_id
            dl.log(f"ireert.Config()")
        else:
            config = get_iree_runtime_config(device)
            dl.log("get_iree_runtime_config")
        if "task" in device:
            print(
                f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
            )
            for flag in get_iree_cpu_rt_args():
                ireert.flags.parse_flags(flag)
        # Now load vmfb.
        # Two scenarios we have here :-
        #      1. We either have the vmfb already saved and therefore pass the path of it.
        #         (This would arise if we're invoking `load_module` from a SharkInference obj)
        # OR   2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
        #         (This would arise if we're invoking `compile` from a SharkInference obj)
        temp_file_to_unlink = None
        if isinstance(flatbuffer_blob_or_path, Path):
            flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
        if (
            isinstance(flatbuffer_blob_or_path, str)
            and ".vmfb" in flatbuffer_blob_or_path
        ):
            vmfb_file_path = flatbuffer_blob_or_path
            mmaped_vmfb = ireert.VmModule.mmap(
                config.vm_instance, flatbuffer_blob_or_path
            )
            vm_modules = []
            if external_weight_file is not None:
                index = ireert.ParameterIndex()
                index.load(external_weight_file)
                param_module = ireert.create_io_parameters_module(
                    config.vm_instance, index.create_provider(scope="model")
                )
                vm_modules.append(param_module)
            vm_modules.append(mmaped_vmfb)
            vm_modules.append(
                ireert.create_hal_module(config.vm_instance, config.device)
            )
            dl.log(f"mmap {flatbuffer_blob_or_path}")
            if "vulkan" in device:
                # Vulkan pipeline creation consumes significant amount of time.
                print(
                    "\tCompiling Vulkan shaders. This may take a few minutes."
                )
            ctx = ireert.SystemContext(config=config, vm_modules=vm_modules)
            dl.log(f"ireert.SystemContext created")
            for flag in shark_args.additional_runtime_args:
                ireert.flags.parse_flags(flag)
            dl.log(f"module initialized")
            mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
        else:
            with tempfile.NamedTemporaryFile(delete=False) as tf:
                tf.write(flatbuffer_blob_or_path)
                tf.flush()
                vmfb_file_path = tf.name
            temp_file_to_unlink = vmfb_file_path
            # `instance` was undefined here; use the config's VM instance.
            mmaped_vmfb = ireert.VmModule.mmap(
                config.vm_instance, vmfb_file_path
            )
            dl.log(f"mmap temp {vmfb_file_path}")
        return mmaped_vmfb, config, temp_file_to_unlink
def get_iree_compiled_module(
    module,
    device: str,
    frontend: str = "torch",
    model_config_path: str = None,
    extra_args: list = [],
    rt_flags: list = [],
    device_idx: int = None,
    mmap: bool = False,
    debug: bool = False,
    compile_str: bool = False,
    external_weight_file: str = None,
    write_to: str = None,
):
    """Given a module returns the compiled .vmfb and configs"""
    flatbuffer_blob = compile_module_to_flatbuffer(
        module=module,
        device=device,
        frontend=frontend,
        model_config_path=model_config_path,
        extra_args=extra_args,
        debug=debug,
        compile_str=compile_str,
        write_to=write_to,
    )
    temp_file_to_unlink = None
    # TODO: Currently mmap=True control flow path has been switched off for mmap.
    #       Got to find a cleaner way to unlink/delete the temporary file since
    #       we're setting delete=False when creating NamedTemporaryFile. That's why
    #       I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
    if mmap:
        if write_to is not None:
            # The blob was written to disk, so mmap it from that path.
            flatbuffer_blob = write_to
        vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
            flatbuffer_blob,
            device,
            device_idx,
            rt_flags,
            external_weight_file=external_weight_file,
        )
    else:
        vmfb, config = get_iree_module(
            flatbuffer_blob,
            device,
            device_idx=device_idx,
            rt_flags=rt_flags,
            external_weight_file=external_weight_file,
        )
    ret_params = {
        "vmfb": vmfb,
        "config": config,
        "temp_file_to_unlink": temp_file_to_unlink,
    }
    return ret_params
def load_flatbuffer(
    flatbuffer_path: str,
    device: str,
    device_idx: int = None,
    mmap: bool = False,
    rt_flags: list = [],
):
    temp_file_to_unlink = None
    if mmap:
        vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
            flatbuffer_path, device, device_idx, rt_flags
        )
    else:
        with open(os.path.join(flatbuffer_path), "rb") as f:
            flatbuffer_blob = f.read()
        vmfb, config = get_iree_module(
            flatbuffer_blob,
            device,
            device_idx=device_idx,
            rt_flags=rt_flags,
        )
    ret_params = {
        "vmfb": vmfb,
        "config": config,
        "temp_file_to_unlink": temp_file_to_unlink,
    }
    return ret_params
def export_iree_module_to_vmfb(
    module,
    device: str,
    directory: str,
    mlir_dialect: str = "linalg",
    model_config_path: str = None,
    module_name: str = None,
    extra_args: list = [],
    debug: bool = False,
    compile_str: bool = False,
):
    # Compiles the module given specs and saves it as .vmfb file.
    flatbuffer_blob = compile_module_to_flatbuffer(
        module=module,
        device=device,
        frontend=mlir_dialect,
        model_config_path=model_config_path,
        extra_args=extra_args,
        debug=debug,
        compile_str=compile_str,
    )
    if module_name is None:
        device_name = (
            device if "://" not in device else "-".join(device.split("://"))
        )
        module_name = f"{mlir_dialect}_{device_name}"
    filename = os.path.join(directory, module_name + ".vmfb")
    with open(filename, "wb") as f:
        f.write(flatbuffer_blob)
    print(f"Saved vmfb in {filename}.")
    return filename
def export_module_to_mlir_file(module, frontend, directory: str):
    # TODO: write proper documentation.
    mlir_str = module
    if frontend in ["tensorflow", "tf", "mhlo", "stablehlo", "tflite"]:
        mlir_str = module.decode("utf-8")
    elif frontend in ["pytorch", "torch"]:
        mlir_str = module.operation.get_asm()
    filename = os.path.join(directory, "model.mlir")
    with open(filename, "w") as f:
        f.write(mlir_str)
    print(f"Saved mlir in {filename}.")
    return filename
def get_results(
    compiled_vm,
    function_name,
    input,
    config,
    frontend="torch",
    send_to_host=True,
    debug_timeout: float = 5.0,
    device: str = None,
):
    """Runs a .vmfb file given inputs and config and returns output."""
    with DetailLogger(debug_timeout) as dl:
        device_inputs = []
        if device == "rocm" and hasattr(config, "id"):
            haldriver = ireert.get_driver("rocm")
            haldevice = haldriver.create_device(
                config.id,
                allocators=shark_args.device_allocator,
            )
        for input_array in input:
            dl.log(f"Load to device: {input_array.shape}")
            device_inputs.append(
                ireert.asdevicearray(config.device, input_array)
            )
        dl.log(f"Invoke function: {function_name}")
        result = compiled_vm[function_name](*device_inputs)
        dl.log(f"Invoke complete")
        result_tensors = []
        if isinstance(result, tuple):
            if send_to_host:
                for val in result:
                    dl.log(f"Result to host: {val.shape}")
                    result_tensors.append(np.asarray(val, val.dtype))
            else:
                for val in result:
                    result_tensors.append(val)
            return result_tensors
        elif isinstance(result, dict):
            data = list(result.items())
            if send_to_host:
                res = np.array(data, dtype=object)
                return np.copy(res)
            return data
        else:
            if send_to_host and result is not None:
                dl.log("Result to host")
                return result.to_host()
            return result
        dl.log("Execution complete")
@functools.cache
def get_iree_runtime_config(device):
    device = iree_device_map(device)
    haldriver = ireert.get_driver(device)
    if "metal" in device and shark_args.device_allocator == "caching":
        print(
            "[WARNING] metal devices can not have a `caching` allocator."
            "\nUsing default allocator `None`"
        )
    haldevice = haldriver.create_device_by_uri(
        device,
        # metal devices have a failure with caching allocators atm. blocking this until it gets fixed upstream.
        allocators=shark_args.device_allocator
        if "metal" not in device
        else None,
    )
    config = ireert.Config(device=haldevice)
    return config
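
The device-string handling in `clean_device_info` above can be tried in isolation with a standalone sketch such as the one below. It mirrors the parsing rules but is not a drop-in replacement. The name `parse_device_string` and the sample label "AMD Radeon => vulkan://1" are illustrative assumptions.

```python
def parse_device_string(raw_device):
    """Split a device label into (driver, device_id)."""
    multi_device_drivers = ("rocm", "vulkan")
    # Drop a human-readable prefix such as "AMD Radeon => ".
    if "=>" in raw_device:
        device = raw_device.split("=>")[1].strip()
    else:
        device = raw_device
    device_id = None
    if "://" in device:
        device, device_id = device.split("://")
        # Short ids are numeric indices; longer ids are kept as strings.
        if len(device_id) <= 2:
            device_id = int(device_id)
    if device not in multi_device_drivers:
        # Other drivers always use their default device.
        device_id = None
    elif device_id is None:
        device_id = 0
    return device, device_id


print(parse_device_string("AMD Radeon => vulkan://1"))
print(parse_device_string("cpu"))
```

Only `rocm` and `vulkan` keep an explicit index. Every other driver falls back to its default device, which matches the comment in `clean_device_info`.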