mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 14:58:11 -05:00
Compare commits
3 Commits
20240606.1
...
debug
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4529fd0461 | ||
|
|
4c2bb4b7b4 | ||
|
|
d5013fd13e |
@@ -25,14 +25,6 @@ def imports():
|
||||
)
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torch")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers")
|
||||
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers")
|
||||
warnings.filterwarnings(
|
||||
action="ignore", category=FutureWarning, module="huggingface-hub"
|
||||
)
|
||||
warnings.filterwarnings(
|
||||
action="ignore", category=UserWarning, module="huggingface-hub"
|
||||
)
|
||||
|
||||
import gradio # noqa: F401
|
||||
|
||||
|
||||
@@ -12,7 +12,10 @@ from tqdm.auto import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
|
||||
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
|
||||
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
|
||||
SharkSDXLPipeline,
|
||||
)
|
||||
|
||||
|
||||
from apps.shark_studio.api.controlnet import control_adapter_map
|
||||
@@ -28,8 +31,11 @@ from apps.shark_studio.modules.img_processing import (
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
from apps.shark_studio.modules.ckpt_processing import (
|
||||
preprocessCKPT,
|
||||
save_irpa,
|
||||
)
|
||||
|
||||
from subprocess import check_output
|
||||
EMPTY_SD_MAP = {
|
||||
"clip": None,
|
||||
"scheduler": None,
|
||||
@@ -61,6 +67,7 @@ def load_script(source, module_name):
|
||||
:param module_name: name of module to register in sys.modules
|
||||
:return: loaded module
|
||||
"""
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, source)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
@@ -92,10 +99,7 @@ class StableDiffusion:
|
||||
import_ir: bool = True,
|
||||
is_controlled: bool = False,
|
||||
external_weights: str = "safetensors",
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
progress(0, desc="Initializing pipeline...")
|
||||
self.ui_device = device
|
||||
self.precision = precision
|
||||
self.compiled_pipeline = False
|
||||
self.base_model_id = base_model_id
|
||||
@@ -108,20 +112,12 @@ class StableDiffusion:
|
||||
"custom_pipeline",
|
||||
)
|
||||
self.turbine_pipe = custom_module.StudioPipeline
|
||||
self.dynamic_steps = False
|
||||
self.model_map = custom_module.MODEL_MAP
|
||||
elif self.is_sdxl:
|
||||
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
|
||||
SharkSDXLPipeline,
|
||||
)
|
||||
self.turbine_pipe = SharkSDXLPipeline
|
||||
self.dynamic_steps = False
|
||||
self.model_map = EMPTY_SDXL_MAP
|
||||
else:
|
||||
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
|
||||
|
||||
self.turbine_pipe = SharkSDPipeline
|
||||
self.dynamic_steps = True
|
||||
self.model_map = EMPTY_SD_MAP
|
||||
max_length = 64
|
||||
target_backend, self.rt_device, triple = parse_device(device, target_triple)
|
||||
@@ -162,7 +158,7 @@ class StableDiffusion:
|
||||
external_weights = None
|
||||
elif target_backend == "llvm-cpu":
|
||||
decomp_attn = False
|
||||
progress(0.5, desc="Initializing pipeline...")
|
||||
|
||||
self.sd_pipe = self.turbine_pipe(
|
||||
hf_model_name=base_model_id,
|
||||
scheduler_id=scheduler,
|
||||
@@ -182,20 +178,13 @@ class StableDiffusion:
|
||||
external_weights=external_weights,
|
||||
custom_vae=custom_vae,
|
||||
)
|
||||
progress(1, desc="Pipeline initialized!...")
|
||||
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
|
||||
gc.collect()
|
||||
|
||||
def prepare_pipe(
|
||||
self,
|
||||
custom_weights,
|
||||
adapters,
|
||||
embeddings,
|
||||
is_img2img,
|
||||
compiled_pipeline,
|
||||
progress=gr.Progress(),
|
||||
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
|
||||
):
|
||||
progress(0, desc="Preparing models...")
|
||||
|
||||
print(f"\n[LOG] Preparing pipeline...")
|
||||
self.is_img2img = False
|
||||
mlirs = copy.deepcopy(self.model_map)
|
||||
vmfbs = copy.deepcopy(self.model_map)
|
||||
@@ -205,10 +194,6 @@ class StableDiffusion:
|
||||
self.compiled_pipeline = compiled_pipeline
|
||||
|
||||
if custom_weights:
|
||||
from apps.shark_studio.modules.ckpt_processing import (
|
||||
preprocessCKPT,
|
||||
save_irpa,
|
||||
)
|
||||
custom_weights = os.path.join(
|
||||
get_checkpoints_path("checkpoints"),
|
||||
safe_name(self.base_model_id.split("/")[-1]),
|
||||
@@ -251,18 +236,17 @@ class StableDiffusion:
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
weights[key] = save_irpa(vae_weights_path, "vae.")
|
||||
progress(0.25, desc=f"Preparing pipeline for {self.ui_device}...")
|
||||
|
||||
vmfbs, weights = self.sd_pipe.check_prepared(
|
||||
mlirs, vmfbs, weights, interactive=False
|
||||
)
|
||||
progress(0.5, desc=f"Artifacts ready!")
|
||||
progress(0.75, desc=f"Loading models and weights...")
|
||||
|
||||
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
|
||||
self.sd_pipe.load_pipeline(
|
||||
vmfbs, weights, self.rt_device, self.compiled_pipeline
|
||||
)
|
||||
progress(1, desc="Pipeline loaded! Generating images...")
|
||||
print(
|
||||
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
|
||||
)
|
||||
return
|
||||
|
||||
def generate_images(
|
||||
@@ -277,9 +261,7 @@ class StableDiffusion:
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
|
||||
img = self.sd_pipe.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -291,7 +273,9 @@ class StableDiffusion:
|
||||
return img
|
||||
|
||||
|
||||
def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
|
||||
def shark_sd_fn_dict_input(
|
||||
sd_kwargs: dict,
|
||||
):
|
||||
print("\n[LOG] Submitting Request...")
|
||||
|
||||
for key in sd_kwargs:
|
||||
@@ -299,8 +283,6 @@ def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
|
||||
sd_kwargs[key] = None
|
||||
if sd_kwargs[key] in ["None"]:
|
||||
sd_kwargs[key] = ""
|
||||
if key in ["steps", "height", "width", "batch_count", "batch_size"]:
|
||||
sd_kwargs[key] = int(sd_kwargs[key])
|
||||
if key == "seed":
|
||||
sd_kwargs[key] = int(sd_kwargs[key])
|
||||
|
||||
@@ -324,7 +306,7 @@ def shark_sd_fn_dict_input(sd_kwargs: dict, *, progress=gr.Progress()):
|
||||
)
|
||||
return None, ""
|
||||
if sd_kwargs["target_triple"] == "":
|
||||
if not parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2]:
|
||||
if parse_device(sd_kwargs["device"], sd_kwargs["target_triple"])[2] == "":
|
||||
gr.Warning(
|
||||
"Target device architecture could not be inferred. Please specify a target triple, e.g. 'gfx1100' for a Radeon 7900xtx."
|
||||
)
|
||||
@@ -358,8 +340,6 @@ def shark_sd_fn(
|
||||
resample_type: str,
|
||||
controlnets: dict,
|
||||
embeddings: dict,
|
||||
seed_increment: str | int = 1,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
sd_kwargs = locals()
|
||||
if not isinstance(sd_init_image, list):
|
||||
@@ -433,9 +413,6 @@ def shark_sd_fn(
|
||||
"control_mode": control_mode,
|
||||
"hints": hints,
|
||||
}
|
||||
if global_obj.get_sd_obj() and global_obj.get_sd_obj().dynamic_steps:
|
||||
submit_run_kwargs["steps"] = submit_pipe_kwargs["steps"]
|
||||
submit_pipe_kwargs.pop("steps")
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
|
||||
@@ -461,12 +438,6 @@ def shark_sd_fn(
|
||||
global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs)
|
||||
|
||||
generated_imgs = []
|
||||
if submit_run_kwargs["seed"] in [-1, "-1"]:
|
||||
submit_run_kwargs["seed"] = randint(0, 4294967295)
|
||||
seed_increment = "random"
|
||||
#print(f"\n[LOG] Random seed: {seed}")
|
||||
progress(None, desc=f"Generating...")
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
|
||||
@@ -485,26 +456,14 @@ def shark_sd_fn(
|
||||
sd_kwargs,
|
||||
)
|
||||
generated_imgs.extend(out_imgs)
|
||||
|
||||
# TODO: make seed changes over batch counts more configurable.
|
||||
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
|
||||
yield generated_imgs, status_label(
|
||||
"Stable Diffusion", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
if batch_count > 1:
|
||||
submit_run_kwargs["seed"] = get_next_seed(seed, seed_increment)
|
||||
|
||||
return (generated_imgs, "")
|
||||
|
||||
|
||||
def get_next_seed(seed, seed_increment: str | int = 10):
|
||||
if isinstance(seed_increment, int):
|
||||
#print(f"\n[LOG] Seed after batch increment: {seed + seed_increment}")
|
||||
return int(seed + seed_increment)
|
||||
elif seed_increment == "random":
|
||||
seed = randint(0, 4294967295)
|
||||
#print(f"\n[LOG] Random seed: {seed}")
|
||||
return seed
|
||||
|
||||
|
||||
def unload_sd():
|
||||
print("Unloading models.")
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
@@ -536,11 +495,11 @@ if __name__ == "__main__":
|
||||
global_obj._init()
|
||||
|
||||
sd_json = view_json_file(
|
||||
get_resource_path(os.path.join(cmd_opts.config_dir, cmd_opts.default_config))
|
||||
get_resource_path(os.path.join(cmd_opts.config_dir, "default_sd_config.json"))
|
||||
)
|
||||
sd_kwargs = json.loads(sd_json)
|
||||
# for arg in vars(cmd_opts):
|
||||
# if arg in sd_kwargs:
|
||||
# sd_kwargs[arg] = getattr(cmd_opts, arg)
|
||||
for arg in vars(cmd_opts):
|
||||
if arg in sd_kwargs:
|
||||
sd_kwargs[arg] = getattr(cmd_opts, arg)
|
||||
for i in shark_sd_fn_dict_input(sd_kwargs):
|
||||
print(i)
|
||||
|
||||
@@ -11,62 +11,17 @@ from pathlib import Path
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
|
||||
def iree_device_map(device):
|
||||
uri_parts = device.split("://", 2)
|
||||
iree_driver = (
|
||||
_IREE_DEVICE_MAP[uri_parts[0]]
|
||||
if uri_parts[0] in _IREE_DEVICE_MAP
|
||||
else uri_parts[0]
|
||||
)
|
||||
if len(uri_parts) == 1:
|
||||
return iree_driver
|
||||
elif "rocm" in uri_parts:
|
||||
return "rocm"
|
||||
else:
|
||||
return f"{iree_driver}://{uri_parts[1]}"
|
||||
|
||||
|
||||
def get_supported_device_list():
|
||||
return list(_IREE_DEVICE_MAP.keys())
|
||||
|
||||
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "hip",
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
|
||||
def iree_target_map(device):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
|
||||
|
||||
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan-spirv",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "rocm",
|
||||
"intel-gpu": "opencl-spirv",
|
||||
}
|
||||
|
||||
# TODO: migrate these utils to studio
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
|
||||
def get_available_devices():
|
||||
return ['rocm', 'cpu']
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
@@ -94,36 +49,36 @@ def get_available_devices():
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
#set_iree_runtime_flags()
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
# cpu_device = get_devices_by_name("cpu-sync")
|
||||
# available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
|
||||
# from shark.iree_utils.vulkan_utils import (
|
||||
# get_all_vulkan_devices,
|
||||
# )
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
# vulkaninfo_list = get_all_vulkan_devices()
|
||||
# vulkan_devices = []
|
||||
# id = 0
|
||||
# for device in vulkaninfo_list:
|
||||
# vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
# id += 1
|
||||
# if id != 0:
|
||||
# print(f"vulkan devices are available.")
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
|
||||
# available_devices.extend(vulkan_devices)
|
||||
# metal_devices = get_devices_by_name("metal")
|
||||
# available_devices.extend(metal_devices)
|
||||
# cuda_devices = get_devices_by_name("cuda")
|
||||
# available_devices.extend(cuda_devices)
|
||||
# hip_devices = get_devices_by_name("hip")
|
||||
# available_devices.extend(hip_devices)
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
hip_devices = get_devices_by_name("hip")
|
||||
available_devices.extend(hip_devices)
|
||||
|
||||
for idx, device_str in enumerate(available_devices):
|
||||
if "AMD Radeon(TM) Graphics =>" in device_str:
|
||||
@@ -140,29 +95,62 @@ def get_available_devices():
|
||||
break
|
||||
return available_devices
|
||||
|
||||
def clean_device_info(raw_device):
|
||||
# return appropriate device and device_id for consumption by Studio pipeline
|
||||
# Multiple devices only supported for vulkan and rocm (as of now).
|
||||
# default device must be selected for all others
|
||||
|
||||
device_id = None
|
||||
device = (
|
||||
raw_device
|
||||
if "=>" not in raw_device
|
||||
else raw_device.split("=>")[1].strip()
|
||||
)
|
||||
if "://" in device:
|
||||
device, device_id = device.split("://")
|
||||
if len(device_id) <= 2:
|
||||
device_id = int(device_id)
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in cmd_opts.device:
|
||||
# set runtime flags for vulkan.
|
||||
set_iree_runtime_flags()
|
||||
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_vulkan_target_triple:
|
||||
triple = get_vulkan_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{cmd_opts.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in cmd_opts.device:
|
||||
cmd_opts.device = "cuda"
|
||||
elif "metal" in cmd_opts.device:
|
||||
device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device)
|
||||
if not cmd_opts.iree_metal_target_platform:
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
|
||||
triple = get_metal_target_triple(device_name)
|
||||
if triple is not None:
|
||||
cmd_opts.iree_metal_target_platform = triple.split("-")[-1]
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple "
|
||||
f"{cmd_opts.iree_metal_target_platform}."
|
||||
)
|
||||
elif "cpu" in cmd_opts.device:
|
||||
cmd_opts.device = "cpu"
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
# TODO: This function should be device-agnostic and piped properly
|
||||
# to general runtime driver init.
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if cmd_opts.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
if cmd_opts.device_allocator_heap_key:
|
||||
vulkan_runtime_flags += [
|
||||
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
if device not in ["hip", "rocm", "vulkan"]:
|
||||
device_id = None
|
||||
if device in ["hip", "rocm", "vulkan"] and device_id == None:
|
||||
device_id = 0
|
||||
return device, device_id
|
||||
|
||||
def parse_device(device_str, target_override=""):
|
||||
from shark.iree_utils.compile_utils import (
|
||||
clean_device_info,
|
||||
get_iree_target_triple,
|
||||
iree_target_map,
|
||||
)
|
||||
|
||||
rt_driver, device_id = clean_device_info(device_str)
|
||||
target_backend = iree_target_map(rt_driver)
|
||||
@@ -172,8 +160,6 @@ def parse_device(device_str, target_override=""):
|
||||
rt_device = rt_driver
|
||||
|
||||
if target_override:
|
||||
if "cpu" in device_str:
|
||||
rt_device = "local-task"
|
||||
return target_backend, rt_device, target_override
|
||||
match target_backend:
|
||||
case "vulkan-spirv":
|
||||
@@ -183,10 +169,7 @@ def parse_device(device_str, target_override=""):
|
||||
triple = get_rocm_target_chip(device_str)
|
||||
return target_backend, rt_device, triple
|
||||
case "llvm-cpu":
|
||||
if "Ryzen 9" in device_str:
|
||||
return target_backend, "local-task", "znver4"
|
||||
else:
|
||||
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
|
||||
return "llvm-cpu", "local-task", "x86_64-linux-gnu"
|
||||
|
||||
|
||||
def get_rocm_target_chip(device_str):
|
||||
@@ -208,7 +191,9 @@ def get_rocm_target_chip(device_str):
|
||||
for key in rocm_chip_map:
|
||||
if key in device_str:
|
||||
return rocm_chip_map[key]
|
||||
return None
|
||||
raise AssertionError(
|
||||
f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
|
||||
)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
@@ -222,69 +207,183 @@ def get_all_devices(driver_name):
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
del driver
|
||||
return device_list_src
|
||||
|
||||
|
||||
# def get_device_mapping(driver, key_combination=3):
|
||||
# """This method ensures consistent device ordering when choosing
|
||||
# specific devices for execution
|
||||
# Args:
|
||||
# driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
# key_combination (int, optional): choice for mapping value for
|
||||
# device name.
|
||||
# 1 : path
|
||||
# 2 : name
|
||||
# 3 : (name, path)
|
||||
# Defaults to 3.
|
||||
# Returns:
|
||||
# dict: map to possible device names user can input mapped to desired
|
||||
# combination of name/path.
|
||||
# """
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
"""This method ensures consistent device ordering when choosing
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for
|
||||
device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired
|
||||
combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
# driver = iree_device_map(driver)
|
||||
# device_list = get_all_devices(driver)
|
||||
# device_map = dict()
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
device_map = dict()
|
||||
|
||||
# def get_output_value(dev_dict):
|
||||
# if key_combination == 1:
|
||||
# return f"{driver}://{dev_dict['path']}"
|
||||
# if key_combination == 2:
|
||||
# return dev_dict["name"]
|
||||
# if key_combination == 3:
|
||||
# return dev_dict["name"], f"{driver}://{dev_dict['path']}"
|
||||
def get_output_value(dev_dict):
|
||||
if key_combination == 1:
|
||||
return f"{driver}://{dev_dict['path']}"
|
||||
if key_combination == 2:
|
||||
return dev_dict["name"]
|
||||
if key_combination == 3:
|
||||
return dev_dict["name"], f"{driver}://{dev_dict['path']}"
|
||||
|
||||
# # mapping driver name to default device (driver://0)
|
||||
# device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
# for i, device in enumerate(device_list):
|
||||
# # mapping with index
|
||||
# device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# # mapping with full path
|
||||
# device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
# return device_map
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
for i, device in enumerate(device_list):
|
||||
# mapping with index
|
||||
device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# mapping with full path
|
||||
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
return device_map
|
||||
|
||||
|
||||
# def get_opt_flags(model, precision="fp16"):
|
||||
# iree_flags = []
|
||||
# if len(cmd_opts.iree_vulkan_target_triple) > 0:
|
||||
# iree_flags.append(
|
||||
# f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
|
||||
# )
|
||||
# if "rocm" in cmd_opts.device:
|
||||
# from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags = []
|
||||
if len(cmd_opts.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
|
||||
)
|
||||
if "rocm" in cmd_opts.device:
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
# rocm_args = get_iree_rocm_args()
|
||||
# iree_flags.extend(rocm_args)
|
||||
# if cmd_opts.iree_constant_folding == False:
|
||||
# iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
# iree_flags.append(
|
||||
# "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
# )
|
||||
# if cmd_opts.data_tiling == False:
|
||||
# iree_flags.append("--iree-opt-data-tiling=False")
|
||||
rocm_args = get_iree_rocm_args()
|
||||
iree_flags.extend(rocm_args)
|
||||
if cmd_opts.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
if cmd_opts.data_tiling == False:
|
||||
iree_flags.append("--iree-opt-data-tiling=False")
|
||||
|
||||
# if "vae" not in model:
|
||||
# # Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# # dims before dispatch formation right now.
|
||||
# iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
# return iree_flags
|
||||
if "vae" not in model:
|
||||
# Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# dims before dispatch formation right now.
|
||||
iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
return iree_flags
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user
|
||||
selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for
|
||||
device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for
|
||||
the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
try:
|
||||
device_mapping = device_map[device]
|
||||
except KeyError:
|
||||
raise ValueError(f"Device '{device}' is not a valid device.")
|
||||
return device_mapping
|
||||
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
driver_name = iree_device_map(driver_name)
|
||||
device_list_dict = get_all_devices(driver_name)
|
||||
print(f"{driver_name} devices are available.")
|
||||
except:
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
cpu_name = get_cpu_info()["brand_raw"]
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_name = (
|
||||
cpu_name if device["name"] == "default" else device["name"]
|
||||
)
|
||||
if "local" in driver_name:
|
||||
device_list.append(
|
||||
f"{device_name} => {driver_name.replace('local', 'cpu')}"
|
||||
)
|
||||
else:
|
||||
# for drivers with single devices
|
||||
# let the default device be selected without any indexing
|
||||
if len(device_list_dict) == 1:
|
||||
device_list.append(f"{device_name} => {driver_name}")
|
||||
else:
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
get_all_vulkan_devices,
|
||||
)
|
||||
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
vulkan_devices = []
|
||||
id = 0
|
||||
for device in vulkaninfo_list:
|
||||
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
|
||||
id += 1
|
||||
if id != 0:
|
||||
print(f"vulkan devices are available.")
|
||||
available_devices.extend(vulkan_devices)
|
||||
metal_devices = get_devices_by_name("metal")
|
||||
available_devices.extend(metal_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
rocm_devices = get_devices_by_name("rocm")
|
||||
available_devices.extend(rocm_devices)
|
||||
cpu_device = get_devices_by_name("cpu-sync")
|
||||
available_devices.extend(cpu_device)
|
||||
cpu_device = get_devices_by_name("cpu-task")
|
||||
available_devices.extend(cpu_device)
|
||||
return available_devices
|
||||
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the
|
||||
# supported range (including -1)
|
||||
def sanitize_seed(seed: int | str):
|
||||
seed = int(seed)
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
seed = randint(uint32_min, uint32_max)
|
||||
return seed
|
||||
|
||||
|
||||
# take a seed expression in an input format and convert it to
|
||||
# a list of integers, where possible
|
||||
def parse_seed_input(seed_input: str | list | int):
|
||||
if isinstance(seed_input, str):
|
||||
try:
|
||||
seed_input = json.loads(seed_input)
|
||||
except (ValueError, TypeError):
|
||||
seed_input = None
|
||||
|
||||
if isinstance(seed_input, int):
|
||||
return [seed_input]
|
||||
|
||||
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
|
||||
return seed_input
|
||||
|
||||
raise TypeError(
|
||||
"Seed input must be an integer or an array of integers in JSON format"
|
||||
)
|
||||
|
||||
@@ -71,14 +71,7 @@ def save_irpa(weights_path, prepend_str):
|
||||
new_key = prepend_str + key
|
||||
archive.add_tensor(new_key, weights[key])
|
||||
|
||||
if "safetensors" in weights_path:
|
||||
irpa_file = weights_path.replace(".safetensors", ".irpa")
|
||||
elif "irpa" in weights_path:
|
||||
irpa_file = weights_path
|
||||
else:
|
||||
return Exception(
|
||||
"Invalid file format. Please provide a .safetensors or .irpa file."
|
||||
)
|
||||
irpa_file = weights_path.replace(".safetensors", ".irpa")
|
||||
archive.save(irpa_file)
|
||||
return irpa_file
|
||||
|
||||
|
||||
@@ -33,8 +33,6 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
|
||||
if extra_info is None:
|
||||
extra_info = {}
|
||||
elif "progress" in extra_info.keys():
|
||||
extra_info.pop("progress")
|
||||
generated_imgs_path = Path(
|
||||
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
|
||||
)
|
||||
|
||||
@@ -101,7 +101,7 @@ def export_scheduler_model(model):
|
||||
|
||||
|
||||
scheduler_model_map = {
|
||||
# "PNDM": export_scheduler_model("PNDMScheduler"),
|
||||
"PNDM": export_scheduler_model("PNDMScheduler"),
|
||||
# "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"),
|
||||
"EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
|
||||
"EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
|
||||
|
||||
@@ -23,6 +23,7 @@ p = argparse.ArgumentParser(
|
||||
##############################################################################
|
||||
# Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"-a",
|
||||
"--app",
|
||||
@@ -34,7 +35,10 @@ p.add_argument(
|
||||
"--prompt",
|
||||
nargs="+",
|
||||
default=[
|
||||
"A hi-res photo of a red street racer drifting around a curve on a mountain, high altitude, at night, tokyo in the background, 8k"
|
||||
"a photo taken of the front of a super-car drifting on a road near "
|
||||
"mountains at high speeds with smoke coming off the tires, front "
|
||||
"angle, front point of view, trees in the mountains of the "
|
||||
"background, ((sharp focus))"
|
||||
],
|
||||
help="Text of which images to be generated.",
|
||||
)
|
||||
@@ -58,7 +62,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=2,
|
||||
default=50,
|
||||
help="The number of steps to do the sampling.",
|
||||
)
|
||||
|
||||
@@ -96,7 +100,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=0,
|
||||
default=7.5,
|
||||
help="The value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
@@ -593,12 +597,6 @@ p.add_argument(
|
||||
##############################################################################
|
||||
# Web UI flags
|
||||
##############################################################################
|
||||
p.add_argument(
|
||||
"--defaults",
|
||||
default="sdxl-turbo.json",
|
||||
type=str,
|
||||
help="Path to the default API request .json file. Works for CLI and webui."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--webui",
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from apps.shark_studio.modules.ckpt_processing import save_irpa
|
||||
import argparse
|
||||
import safetensors
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
default="",
|
||||
help="input safetensors/irpa",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default="",
|
||||
help="prefix to add to all the keys in the irpa",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
output_file = save_irpa(args.input, args.prefix)
|
||||
print("saved irpa to", output_file, "with prefix", args.prefix)
|
||||
@@ -83,7 +83,7 @@ def webui():
|
||||
launch_api = cmd_opts.api
|
||||
initialize.initialize()
|
||||
|
||||
#from ui.chat import chat_element
|
||||
from ui.chat import chat_element
|
||||
from ui.sd import sd_element
|
||||
from ui.outputgallery import outputgallery_element
|
||||
|
||||
@@ -170,7 +170,7 @@ def webui():
|
||||
css=dark_theme,
|
||||
js=gradio_workarounds,
|
||||
analytics_enabled=False,
|
||||
title="Shark Studio 2.0",
|
||||
title="Shark Studio 2.0 Beta",
|
||||
) as studio_web:
|
||||
amd_logo = Image.open(amdlogo_loc)
|
||||
gr.Image(
|
||||
@@ -194,8 +194,8 @@ def webui():
|
||||
sd_element.render()
|
||||
with gr.TabItem(label="Output Gallery", id=1):
|
||||
outputgallery_element.render()
|
||||
# with gr.TabItem(label="Chat Bot", id=2):
|
||||
# chat_element.render()
|
||||
with gr.TabItem(label="Chat Bot", id=2):
|
||||
chat_element.render()
|
||||
|
||||
studio_web.queue()
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from apps.shark_studio.web.utils.file_utils import (
|
||||
get_checkpoints_path,
|
||||
get_checkpoints,
|
||||
get_configs_path,
|
||||
get_configs,
|
||||
write_default_sd_configs,
|
||||
)
|
||||
from apps.shark_studio.api.sd import (
|
||||
@@ -45,15 +44,13 @@ from apps.shark_studio.web.ui.common_events import lora_changed
|
||||
from apps.shark_studio.modules import logger
|
||||
import apps.shark_studio.web.utils.globals as global_obj
|
||||
|
||||
# Disabled some models for demo purposes
|
||||
sd_default_models = [
|
||||
# "runwayml/stable-diffusion-v1-5",
|
||||
# "stabilityai/stable-diffusion-2-1-base",
|
||||
# "stabilityai/stable-diffusion-2-1",
|
||||
# "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
]
|
||||
sd_default_models.extend(get_checkpoints(model_type="scripts"))
|
||||
|
||||
|
||||
def view_json_file(file_path):
|
||||
@@ -149,14 +146,7 @@ def pull_sd_configs(
|
||||
|
||||
|
||||
def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
if os.path.exists(load_sd_config):
|
||||
config = load_sd_config
|
||||
elif os.path.exists(os.path.join(get_configs_path(), load_sd_config)):
|
||||
config = os.path.join(get_configs_path(), load_sd_config)
|
||||
else:
|
||||
print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
|
||||
config = sd_json
|
||||
new_sd_config = none_to_str_none(json.loads(view_json_file(config)))
|
||||
new_sd_config = none_to_str_none(json.loads(view_json_file(load_sd_config)))
|
||||
if sd_json:
|
||||
for key in new_sd_config:
|
||||
sd_json[key] = new_sd_config[key]
|
||||
@@ -168,8 +158,6 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_image = [Image.open(i, mode="r")]
|
||||
else:
|
||||
sd_image = None
|
||||
if not sd_json["device"]:
|
||||
sd_json["device"] = gr.update()
|
||||
|
||||
return [
|
||||
sd_json["prompt"][0],
|
||||
@@ -177,7 +165,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
|
||||
sd_image,
|
||||
sd_json["height"],
|
||||
sd_json["width"],
|
||||
gr.update(),
|
||||
sd_json["steps"],
|
||||
sd_json["strength"],
|
||||
sd_json["guidance_scale"],
|
||||
sd_json["seed"],
|
||||
@@ -210,7 +198,7 @@ def save_sd_cfg(config: dict, save_name: str):
|
||||
filepath += ".json"
|
||||
with open(filepath, mode="w") as f:
|
||||
f.write(json.dumps(config))
|
||||
return save_name
|
||||
return "..."
|
||||
|
||||
|
||||
def create_canvas(width, height):
|
||||
@@ -247,249 +235,89 @@ def base_model_changed(base_model_id):
|
||||
new_choices = get_checkpoints(
|
||||
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
|
||||
) + get_checkpoints(model_type="checkpoints")
|
||||
if "turbo" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=2,
|
||||
choices=[1, 2],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
if "stable-diffusion-xl-base-1.0" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=40,
|
||||
choices=[20, 25, 30, 35, 40, 45, 50],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
elif ".py" in base_model_id:
|
||||
new_steps = gr.Dropdown(
|
||||
value=20,
|
||||
choices=[10, 15, 20],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
else:
|
||||
new_steps = gr.Dropdown(
|
||||
value=20,
|
||||
choices=[10, 20, 30, 40, 50],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
return [
|
||||
gr.Dropdown(
|
||||
value=new_choices[0] if len(new_choices) > 0 else "None",
|
||||
choices=["None"] + new_choices,
|
||||
),
|
||||
new_steps,
|
||||
]
|
||||
return gr.Dropdown(
|
||||
value=new_choices[0] if len(new_choices) > 0 else "None",
|
||||
choices=["None"] + new_choices,
|
||||
)
|
||||
|
||||
init_config = global_obj.get_init_config()
|
||||
init_config = none_to_str_none(json.loads(view_json_file(init_config)))
|
||||
|
||||
with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
with gr.Column(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2, min_width=600):
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="\U00002795\U0000FE0F Prompt",
|
||||
value=init_config["prompt"][0],
|
||||
lines=4,
|
||||
elem_id="prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="\U00002796\U0000FE0F Negative Prompt",
|
||||
value=init_config["negative_prompt"][0],
|
||||
lines=4,
|
||||
elem_id="negative_prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="\U0001F4D0\U0000FE0F Advanced Settings", open=False
|
||||
label="\U0001F4D0\U0000FE0F Device Settings", open=False
|
||||
):
|
||||
with gr.Accordion(
|
||||
label="Device Settings", open=False
|
||||
):
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=init_config["device"] if init_config["device"] else "rocm",
|
||||
choices=global_obj.get_device_list(),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
target_triple = gr.Textbox(
|
||||
elem_id="target_triple",
|
||||
label="Architecture",
|
||||
value=init_config["target_triple"],
|
||||
)
|
||||
with gr.Row():
|
||||
ondemand = gr.Checkbox(
|
||||
value=init_config["ondemand"],
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=init_config["precision"],
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=global_obj.get_device_list()[0],
|
||||
choices=global_obj.get_device_list(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
target_triple = gr.Textbox(
|
||||
elem_id="target_triple",
|
||||
label="Architecture",
|
||||
value="",
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
512,
|
||||
1024,
|
||||
value=512,
|
||||
step=512,
|
||||
label="\U00002195\U0000FE0F Height",
|
||||
interactive=False, # DEMO
|
||||
visible=False, # DEMO
|
||||
)
|
||||
width = gr.Slider(
|
||||
512,
|
||||
1024,
|
||||
value=512,
|
||||
step=512,
|
||||
label="\U00002194\U0000FE0F Width",
|
||||
interactive=False, # DEMO
|
||||
visible=False, # DEMO
|
||||
)
|
||||
|
||||
with gr.Accordion(
|
||||
label="\U0001F9EA\U0000FE0F Input Image Processing",
|
||||
open=False,
|
||||
visible=False,
|
||||
):
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=init_config["strength"],
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=init_config["resample_type"],
|
||||
choices=resampler_list,
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
|
||||
base_model_id = gr.Dropdown(
|
||||
label="\U000026F0\U0000FE0F Base Model",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value=init_config["base_model_id"],
|
||||
choices=sd_default_models,
|
||||
allow_custom_value=True,
|
||||
) # base_model_id
|
||||
with gr.Row(equal_height=True):
|
||||
seed = gr.Textbox(
|
||||
value=init_config["seed"],
|
||||
label="\U0001F331\U0000FE0F Seed",
|
||||
info="An integer, -1 for random",
|
||||
show_copy_button=True,
|
||||
)
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="\U0001F4C5\U0000FE0F Scheduler",
|
||||
info="\U000E0020", # forces same height as seed
|
||||
value=init_config["scheduler"],
|
||||
choices=scheduler_model_map.keys(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Dropdown(
|
||||
value=20,
|
||||
choices=[10, 15, 20],
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
5, #DEMO
|
||||
value=4,
|
||||
step=0.1,
|
||||
label="\U0001F5C3\U0000FE0F CFG Scale",
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=init_config["batch_count"],
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
ondemand = gr.Checkbox(
|
||||
value=cmd_opts.lowvram,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value=cmd_opts.precision,
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=init_config["batch_size"],
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False, # DEMO
|
||||
visible=True,
|
||||
)
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
value=init_config["compiled_pipeline"],
|
||||
label="Faster txt2img (SDXL only)",
|
||||
visible=False, # DEMO
|
||||
)
|
||||
with gr.Row(elem_classes=["fill"], visible=False):
|
||||
Path(get_configs_path()).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
write_default_sd_configs(get_configs_path())
|
||||
default_config_file = global_obj.get_init_config()
|
||||
sd_json = gr.JSON(
|
||||
elem_classes=["fill"],
|
||||
value=view_json_file(default_config_file),
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
load_sd_config = gr.Dropdown(
|
||||
label="Load Config",
|
||||
value=cmd_opts.defaults,
|
||||
choices=get_configs(),
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
save_sd_config = gr.Button(
|
||||
value="Save Config", size="sm"
|
||||
)
|
||||
clear_sd_config = gr.ClearButton(
|
||||
value="Clear Config",
|
||||
size="sm",
|
||||
components=sd_json,
|
||||
)
|
||||
# with gr.Row():
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
)
|
||||
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
|
||||
base_model_id = gr.Dropdown(
|
||||
label="\U000026F0\U0000FE0F Base Model",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="stabilityai/stable-diffusion-2-1-base",
|
||||
choices=sd_default_models,
|
||||
allow_custom_value=True,
|
||||
) # base_model_id
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384,
|
||||
1024,
|
||||
value=cmd_opts.height,
|
||||
step=8,
|
||||
label="\U00002195\U0000FE0F Height",
|
||||
)
|
||||
width = gr.Slider(
|
||||
384,
|
||||
1024,
|
||||
value=cmd_opts.width,
|
||||
step=8,
|
||||
label="\U00002194\U0000FE0F Width",
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="\U00002696\U0000FE0F Model Weights",
|
||||
open=False,
|
||||
visible=False, # DEMO
|
||||
label="\U00002696\U0000FE0F Model Weights", open=False
|
||||
):
|
||||
with gr.Column():
|
||||
custom_weights = gr.Dropdown(
|
||||
label="Checkpoint Weights",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value=init_config["custom_weights"],
|
||||
value="None",
|
||||
allow_custom_value=True,
|
||||
choices=["None"]
|
||||
+ get_checkpoints(os.path.basename(str(base_model_id))),
|
||||
) # custom_weights
|
||||
base_model_id.change(
|
||||
fn=base_model_changed,
|
||||
inputs=[base_model_id],
|
||||
outputs=[custom_weights],
|
||||
)
|
||||
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
|
||||
"\\", "\n\\"
|
||||
)
|
||||
@@ -498,7 +326,11 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
label=f"VAE Model",
|
||||
info=sd_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=init_config["custom_vae"],
|
||||
value=(
|
||||
os.path.basename(cmd_opts.custom_vae)
|
||||
if cmd_opts.custom_vae
|
||||
else "None"
|
||||
),
|
||||
choices=["None"] + get_checkpoints("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=1,
|
||||
@@ -511,7 +343,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=sd_lora_info,
|
||||
elem_id="lora_weights",
|
||||
value=init_config["embeddings"][0] if (len(init_config["embeddings"].keys()) > 1) else "None",
|
||||
value=None,
|
||||
multiselect=True,
|
||||
choices=[] + get_checkpoints("lora"),
|
||||
scale=2,
|
||||
@@ -536,6 +368,67 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
outputs=[embeddings_config],
|
||||
show_progress=False,
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="\U0001F9EA\U0000FE0F Input Image Processing", open=False
|
||||
):
|
||||
strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=cmd_opts.strength,
|
||||
step=0.01,
|
||||
label="Denoising Strength",
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=cmd_opts.resample_type,
|
||||
choices=resampler_list,
|
||||
label="Resample Type",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="\U00002795\U0000FE0F Prompt",
|
||||
value=cmd_opts.prompt[0],
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="\U00002796\U0000FE0F Negative Prompt",
|
||||
value=cmd_opts.negative_prompt[0],
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
seed = gr.Textbox(
|
||||
value=cmd_opts.seed,
|
||||
label="\U0001F331\U0000FE0F Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
show_copy_button=True,
|
||||
)
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="\U0001F4C5\U0000FE0F Scheduler",
|
||||
info="\U000E0020", # forces same height as seed
|
||||
value="EulerDiscrete",
|
||||
choices=scheduler_model_map.keys(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=cmd_opts.steps,
|
||||
step=1,
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
)
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=cmd_opts.guidance_scale,
|
||||
step=0.1,
|
||||
label="\U0001F5C3\U0000FE0F CFG Scale",
|
||||
)
|
||||
with gr.Accordion(
|
||||
label="Controlnet Options",
|
||||
open=False,
|
||||
@@ -585,17 +478,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=512,
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=512,
|
||||
step=8,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=512,
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=512,
|
||||
step=8,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
@@ -665,9 +558,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
lambda: gr.Tabs(selected=101),
|
||||
outputs=[sd_tabs],
|
||||
)
|
||||
with gr.Tab(
|
||||
label="Input Image", id=100, visible=False
|
||||
) as sd_tab_init_image: # DEMO
|
||||
with gr.Tab(label="Input Image", id=100) as sd_tab_init_image:
|
||||
with gr.Column(elem_classes=["sd-right-panel"]):
|
||||
with gr.Row(elem_classes=["fill"]):
|
||||
# TODO: make this import image prompt info if it exists
|
||||
@@ -697,6 +588,28 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
object_fit="fit",
|
||||
preview=True,
|
||||
)
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=cmd_opts.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=cmd_opts.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=True,
|
||||
visible=True,
|
||||
)
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
False,
|
||||
label="Faster txt2img (SDXL only)",
|
||||
)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Start")
|
||||
unload = gr.Button("Unload Models")
|
||||
@@ -705,44 +618,91 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop", visible=False)
|
||||
# with gr.Tab(label="Config", id=102) as sd_tab_config:
|
||||
# with gr.Group():#elem_classes=["sd-right-panel"]):
|
||||
# with gr.Row(elem_classes=["fill"], visible=False):
|
||||
# Path(get_configs_path()).mkdir(
|
||||
# parents=True, exist_ok=True
|
||||
# )
|
||||
# write_default_sd_configs(get_configs_path())
|
||||
# default_config_file = global_obj.get_init_config()
|
||||
# sd_json = gr.JSON(
|
||||
# elem_classes=["fill"],
|
||||
# value=view_json_file(default_config_file),
|
||||
# )
|
||||
# with gr.Row():
|
||||
# with gr.Row():
|
||||
# load_sd_config = gr.Dropdown(
|
||||
# label="Load Config",
|
||||
# value=cmd_opts.defaults,
|
||||
# choices=get_configs(),
|
||||
# allow_custom_value=True,
|
||||
# )
|
||||
# with gr.Row():
|
||||
# save_sd_config = gr.Button(
|
||||
# value="Save Config", size="sm"
|
||||
# )
|
||||
# clear_sd_config = gr.ClearButton(
|
||||
# value="Clear Config",
|
||||
# size="sm",
|
||||
# components=sd_json,
|
||||
# )
|
||||
# # with gr.Row():
|
||||
# sd_config_name = gr.Textbox(
|
||||
# value="Config Name",
|
||||
# info="Name of the file this config will be saved to.",
|
||||
# interactive=True,
|
||||
# show_label=False,
|
||||
# )
|
||||
with gr.Tab(label="Log", id=103, visible=False) as sd_tab_log:
|
||||
stop_batch = gr.Button("Stop")
|
||||
with gr.Tab(label="Config", id=102) as sd_tab_config:
|
||||
with gr.Column(elem_classes=["sd-right-panel"]):
|
||||
with gr.Row(elem_classes=["fill"]):
|
||||
Path(get_configs_path()).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
default_config_file = os.path.join(
|
||||
get_configs_path(),
|
||||
"default_sd_config.json",
|
||||
)
|
||||
write_default_sd_configs(get_configs_path())
|
||||
sd_json = gr.JSON(
|
||||
elem_classes=["fill"],
|
||||
value=view_json_file(default_config_file),
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
load_sd_config = gr.FileExplorer(
|
||||
label="Load Config",
|
||||
file_count="single",
|
||||
root_dir=(
|
||||
cmd_opts.configs_path
|
||||
if cmd_opts.configs_path
|
||||
else get_configs_path()
|
||||
),
|
||||
height=75,
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
save_sd_config = gr.Button(
|
||||
value="Save Config", size="sm"
|
||||
)
|
||||
clear_sd_config = gr.ClearButton(
|
||||
value="Clear Config",
|
||||
size="sm",
|
||||
components=sd_json,
|
||||
)
|
||||
with gr.Row():
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
)
|
||||
load_sd_config.change(
|
||||
fn=load_sd_cfg,
|
||||
inputs=[sd_json, load_sd_config],
|
||||
outputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
sd_init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
base_model_id,
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
sd_json,
|
||||
],
|
||||
)
|
||||
save_sd_config.click(
|
||||
fn=save_sd_cfg,
|
||||
inputs=[sd_json, sd_config_name],
|
||||
outputs=[sd_config_name],
|
||||
)
|
||||
save_sd_config.click(
|
||||
fn=save_sd_cfg,
|
||||
inputs=[sd_json, sd_config_name],
|
||||
outputs=[sd_config_name],
|
||||
)
|
||||
with gr.Tab(label="Log", id=103) as sd_tab_log:
|
||||
with gr.Row():
|
||||
std_output = gr.Textbox(
|
||||
value=f"{sd_model_info}\n"
|
||||
@@ -758,46 +718,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
logger.read_sd_logs, None, std_output, every=1
|
||||
)
|
||||
sd_status = gr.Textbox(visible=False)
|
||||
base_model_id.change(
|
||||
fn=base_model_changed,
|
||||
inputs=[base_model_id],
|
||||
outputs=[custom_weights, steps],
|
||||
)
|
||||
load_sd_config.change(
|
||||
fn=load_sd_cfg,
|
||||
inputs=[sd_json, load_sd_config],
|
||||
outputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
sd_init_image,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
strength,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
base_model_id,
|
||||
custom_weights,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
target_triple,
|
||||
ondemand,
|
||||
compiled_pipeline,
|
||||
resample_type,
|
||||
cnet_config,
|
||||
embeddings_config,
|
||||
sd_json,
|
||||
],
|
||||
)
|
||||
save_sd_config.click(
|
||||
fn=save_sd_cfg,
|
||||
inputs=[sd_json, sd_config_name],
|
||||
outputs=[sd_config_name],
|
||||
)
|
||||
|
||||
pull_kwargs = dict(
|
||||
fn=pull_sd_configs,
|
||||
inputs=[
|
||||
|
||||
@@ -89,7 +89,7 @@ sdxl_turbo = r"""{
|
||||
}"""
|
||||
|
||||
default_sd_configs = {
|
||||
# "default_sd_config.json": sdxl_turbo,
|
||||
# "sdxl-30steps.json": sdxl_30steps,
|
||||
"default_sd_config.json": default_sd_config,
|
||||
"sdxl-30steps.json": sdxl_30steps,
|
||||
"sdxl-turbo.json": sdxl_turbo,
|
||||
}
|
||||
|
||||
@@ -17,9 +17,8 @@ from apps.shark_studio.web.utils.default_configs import default_sd_configs
|
||||
def write_default_sd_configs(path):
|
||||
for key in default_sd_configs.keys():
|
||||
config_fpath = os.path.join(path, key)
|
||||
if not os.path.exists(config_fpath):
|
||||
with open(config_fpath, "w") as f:
|
||||
f.write(default_sd_configs[key])
|
||||
with open(config_fpath, "w") as f:
|
||||
f.write(default_sd_configs[key])
|
||||
|
||||
|
||||
def safe_name(name):
|
||||
@@ -88,8 +87,6 @@ def get_checkpoints_path(model_type=""):
|
||||
def get_checkpoints(model_type="checkpoints"):
|
||||
ckpt_files = []
|
||||
file_types = checkpoints_filetypes
|
||||
if model_type == "scripts":
|
||||
file_types = ["shark_*.py"]
|
||||
if model_type == "lora":
|
||||
file_types = file_types + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
@@ -100,15 +97,6 @@ def get_checkpoints(model_type="checkpoints"):
|
||||
ckpt_files.extend(files)
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
def get_configs():
|
||||
return sorted(
|
||||
[
|
||||
os.path.basename(x)
|
||||
for x in glob.glob(os.path.join(get_configs_path(), "*.json"))
|
||||
],
|
||||
key=str.casefold,
|
||||
)
|
||||
|
||||
|
||||
def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"):
|
||||
return os.path.join(get_checkpoints_path(model_type), checkpoint_name)
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
import gc
|
||||
from ...api.utils import get_available_devices
|
||||
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
|
||||
import os
|
||||
from apps.shark_studio.web.utils.file_utils import get_configs_path
|
||||
|
||||
"""
|
||||
The global objects include SD pipeline and config.
|
||||
Maintaining the global objects would avoid creating extra pipeline objects when switching modes.
|
||||
Also we could avoid memory leak when switching models by clearing the cache.
|
||||
"""
|
||||
def view_json_file(file_path):
|
||||
content = ""
|
||||
with open(file_path, "r") as fopen:
|
||||
content = fopen.read()
|
||||
return content
|
||||
|
||||
|
||||
def _init():
|
||||
global _sd_obj
|
||||
@@ -95,16 +89,6 @@ def get_device_list():
|
||||
global _devices
|
||||
return _devices
|
||||
|
||||
def get_init_config():
|
||||
global _init_config
|
||||
if os.path.exists(cmd_opts.defaults):
|
||||
_init_config = cmd_opts.defaults
|
||||
elif os.path.exists(os.path.join(get_configs_path(), cmd_opts.defaults)):
|
||||
_init_config = os.path.join(get_configs_path(), cmd_opts.defaults)
|
||||
else:
|
||||
print("Default config not found as absolute path or in configs folder. Using sdxl-turbo as default config.")
|
||||
_init_config = os.path.join(get_configs_path(), "sdxl-turbo.json")
|
||||
return _init_config
|
||||
|
||||
def get_sd_status():
|
||||
global _sd_obj
|
||||
|
||||
0
benchmarks/__init__.py
Normal file
0
benchmarks/__init__.py
Normal file
22
benchmarks/hf_model_benchmark.py
Normal file
22
benchmarks/hf_model_benchmark.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
from shark.parser import parser
|
||||
from benchmarks.hf_transformer import SharkHFBenchmarkRunner
|
||||
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
required=True,
|
||||
help='Specifies name of HF model to benchmark. (For exmaple "microsoft/MiniLM-L12-H384-uncased"',
|
||||
)
|
||||
load_args, unknown = parser.parse_known_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
model_name = load_args.model_name
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
shark_module = SharkHFBenchmarkRunner(
|
||||
model_name, (test_input,), jit_trace=True
|
||||
)
|
||||
shark_module.benchmark_c()
|
||||
shark_module.benchmark_python((test_input,))
|
||||
shark_module.benchmark_torch(test_input)
|
||||
shark_module.benchmark_onnx(test_input)
|
||||
181
benchmarks/hf_transformer.py
Normal file
181
benchmarks/hf_transformer.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
from shark.shark_benchmark_runner import SharkBenchmarkRunner
|
||||
from shark.parser import shark_args
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from onnxruntime.transformers.benchmark import (
|
||||
run_pytorch,
|
||||
run_tensorflow,
|
||||
run_onnxruntime,
|
||||
)
|
||||
from onnxruntime.transformers.huggingface_models import MODELS
|
||||
from onnxruntime.transformers.benchmark_helper import ConfigModifier, Precision
|
||||
import os
|
||||
import psutil
|
||||
|
||||
|
||||
class OnnxFusionOptions(object):
|
||||
def __init__(self):
|
||||
self.disable_gelu = False
|
||||
self.disable_layer_norm = False
|
||||
self.disable_attention = False
|
||||
self.disable_skip_layer_norm = False
|
||||
self.disable_embed_layer_norm = False
|
||||
self.disable_bias_skip_layer_norm = False
|
||||
self.disable_bias_gelu = False
|
||||
self.enable_gelu_approximation = False
|
||||
self.use_mask_index = False
|
||||
self.no_attention_mask = False
|
||||
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
def __init__(self, hf_model_name):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
hf_model_name, # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
class SharkHFBenchmarkRunner(SharkBenchmarkRunner):
|
||||
# SharkRunner derived class with Benchmarking capabilities.
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
input: tuple,
|
||||
dynamic: bool = False,
|
||||
device: str = None,
|
||||
jit_trace: bool = False,
|
||||
from_aot: bool = False,
|
||||
frontend: str = "torch",
|
||||
):
|
||||
self.device = device if device is not None else shark_args.device
|
||||
if self.device == "gpu":
|
||||
raise ValueError(
|
||||
"Currently GPU Benchmarking is not supported due to OOM from ORT."
|
||||
)
|
||||
self.model_name = model_name
|
||||
model = HuggingFaceLanguage(model_name)
|
||||
SharkBenchmarkRunner.__init__(
|
||||
self,
|
||||
model,
|
||||
input,
|
||||
dynamic,
|
||||
self.device,
|
||||
jit_trace,
|
||||
from_aot,
|
||||
frontend,
|
||||
)
|
||||
|
||||
def benchmark_torch(self, inputs):
|
||||
use_gpu = self.device == "gpu"
|
||||
# Set set the model's layer number to automatic.
|
||||
config_modifier = ConfigModifier(None)
|
||||
num_threads = psutil.cpu_count(logical=False)
|
||||
batch_sizes = [inputs.shape[0]]
|
||||
sequence_lengths = [inputs.shape[-1]]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
verbose = False
|
||||
result = run_pytorch(
|
||||
use_gpu,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
False,
|
||||
cache_dir,
|
||||
verbose,
|
||||
)
|
||||
print(
|
||||
f"ONNX Pytorch-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
|
||||
# TODO: Currently non-functional due to TF runtime error. There might be some issue with, initializing TF.
|
||||
def benchmark_tf(self, inputs):
|
||||
use_gpu = self.device == "gpu"
|
||||
# Set set the model's layer number to automatic.
|
||||
config_modifier = ConfigModifier(None)
|
||||
num_threads = psutil.cpu_count(logical=False)
|
||||
batch_sizes = [inputs.shape[0]]
|
||||
sequence_lengths = [inputs.shape[-1]]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
verbose = False
|
||||
result = run_tensorflow(
|
||||
use_gpu,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
cache_dir,
|
||||
verbose,
|
||||
)
|
||||
print(
|
||||
f"ONNX TF-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
|
||||
def benchmark_onnx(self, inputs):
|
||||
if self.model_name not in MODELS:
|
||||
print(
|
||||
f"{self.model_name} is currently not supported in ORT's HF. Check \
|
||||
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/huggingface_models.py \
|
||||
for currently supported models. Exiting benchmark ONNX."
|
||||
)
|
||||
return
|
||||
use_gpu = self.device == "gpu"
|
||||
num_threads = psutil.cpu_count(logical=False)
|
||||
batch_sizes = [inputs.shape[0]]
|
||||
sequence_lengths = [inputs.shape[-1]]
|
||||
cache_dir = os.path.join(".", "cache_models")
|
||||
onnx_dir = os.path.join(".", "onnx_models")
|
||||
verbose = False
|
||||
input_counts = [1]
|
||||
optimize_onnx = True
|
||||
validate_onnx = False
|
||||
disable_ort_io_binding = False
|
||||
use_raw_attention_mask = True
|
||||
model_fusion_statistics = {}
|
||||
overwrite = False
|
||||
model_source = "pt" # Either "pt" or "tf"
|
||||
provider = None
|
||||
config_modifier = ConfigModifier(None)
|
||||
onnx_args = OnnxFusionOptions()
|
||||
result = run_onnxruntime(
|
||||
use_gpu,
|
||||
provider,
|
||||
[self.model_name],
|
||||
None,
|
||||
config_modifier,
|
||||
Precision.FLOAT32,
|
||||
num_threads,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
shark_args.num_iterations,
|
||||
input_counts,
|
||||
optimize_onnx,
|
||||
validate_onnx,
|
||||
cache_dir,
|
||||
onnx_dir,
|
||||
verbose,
|
||||
overwrite,
|
||||
disable_ort_io_binding,
|
||||
use_raw_attention_mask,
|
||||
model_fusion_statistics,
|
||||
model_source,
|
||||
onnx_args,
|
||||
)
|
||||
print(
|
||||
f"ONNX ORT-benchmark:{result[0]['QPS']} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
231
benchmarks/tests/test_benchmark.py
Normal file
231
benchmarks/tests/test_benchmark.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.iree_utils._common import check_device_drivers
|
||||
|
||||
import torch
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import torchvision.models as models
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
BertTokenizer,
|
||||
TFBertModel,
|
||||
)
|
||||
import importlib
|
||||
import pytest
|
||||
import unittest
|
||||
|
||||
torch.manual_seed(0)
|
||||
gpus = tf.config.experimental.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
|
||||
##################### Tensorflow Hugging Face LM Models ###################################
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
tf_bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class TFHuggingFaceLanguage(tf.Module):
|
||||
def __init__(self, hf_model_name):
|
||||
super(TFHuggingFaceLanguage, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
self.m = TFBertModel.from_pretrained(hf_model_name, from_pt=True)
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=tf_bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
def get_TFhf_model(name):
|
||||
model = TFHuggingFaceLanguage(name)
|
||||
tokenizer = BertTokenizer.from_pretrained(name)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
test_input = (
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
actual_out = model.forward(*test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
##################### Hugging Face LM Models ###################################
|
||||
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
def __init__(self, hf_model_name):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
hf_model_name, # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
def get_hf_model(name):
|
||||
model = HuggingFaceLanguage(name)
|
||||
# TODO: Currently the test input is set to (1,128)
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
##################### Torch Vision Models ###################################
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.train(False)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
def get_vision_model(torch_model):
|
||||
model = VisionModule(torch_model)
|
||||
# TODO: Currently the test input is set to (1,128)
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
actual_out = model(test_input)
|
||||
return model, test_input, actual_out
|
||||
|
||||
|
||||
############################# Benchmark Tests ####################################
|
||||
|
||||
pytest_benchmark_param = pytest.mark.parametrize(
|
||||
("dynamic", "device"),
|
||||
[
|
||||
pytest.param(False, "cpu"),
|
||||
# TODO: Language models are failing for dynamic case..
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"cuda",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason="nvidia-smi not found"
|
||||
),
|
||||
),
|
||||
pytest.param(True, "cuda", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"vulkan",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
"vulkan",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("vulkan"),
|
||||
reason="vulkaninfo not found, install from https://github.com/KhronosGroup/MoltenVK/releases",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF",
|
||||
)
|
||||
@pytest_benchmark_param
|
||||
def test_bench_minilm_torch(dynamic, device):
|
||||
model, test_input, act_out = get_hf_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
(test_input,),
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
try:
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all((test_input,))
|
||||
assert True
|
||||
except Exception as e:
|
||||
# If anything happen during benchmarking, assert False/failure.
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("iree.tools") is None,
|
||||
reason="Cannot find tools to import TF",
|
||||
)
|
||||
@pytest_benchmark_param
|
||||
def test_bench_distilbert(dynamic, device):
|
||||
model, test_input, act_out = get_TFhf_model("distilbert-base-uncased")
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
test_input,
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
try:
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all(test_input)
|
||||
assert True
|
||||
except Exception as e:
|
||||
# If anything happen during benchmarking, assert False/failure.
|
||||
assert False
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="XLM Roberta too large to test.")
|
||||
@pytest_benchmark_param
|
||||
def test_bench_xlm_roberta(dynamic, device):
|
||||
model, test_input, act_out = get_TFhf_model("xlm-roberta-base")
|
||||
shark_module = SharkInference(
|
||||
model,
|
||||
test_input,
|
||||
device=device,
|
||||
dynamic=dynamic,
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
try:
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all(test_input)
|
||||
assert True
|
||||
except Exception as e:
|
||||
# If anything happen during benchmarking, assert False/failure.
|
||||
assert False
|
||||
45
benchmarks/tests/test_hf_benchmark.py
Normal file
45
benchmarks/tests/test_hf_benchmark.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
from benchmarks.hf_transformer import SharkHFBenchmarkRunner
|
||||
import importlib
|
||||
import pytest
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
############################# HF Benchmark Tests ####################################
|
||||
|
||||
# Test running benchmark module without failing.
|
||||
pytest_benchmark_param = pytest.mark.parametrize(
|
||||
("dynamic", "device"),
|
||||
[
|
||||
pytest.param(False, "cpu"),
|
||||
# TODO: Language models are failing for dynamic case..
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
importlib.util.find_spec("onnxruntime") is None,
|
||||
reason="Cannot find ONNXRUNTIME.",
|
||||
)
|
||||
@pytest_benchmark_param
|
||||
def test_HFbench_minilm_torch(dynamic, device):
|
||||
model_name = "bert-base-uncased"
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
try:
|
||||
shark_module = SharkHFBenchmarkRunner(
|
||||
model_name,
|
||||
(test_input,),
|
||||
jit_trace=True,
|
||||
dynamic=dynamic,
|
||||
device=device,
|
||||
)
|
||||
shark_module.benchmark_c()
|
||||
shark_module.benchmark_python((test_input,))
|
||||
shark_module.benchmark_torch(test_input)
|
||||
shark_module.benchmark_onnx(test_input)
|
||||
# If becnhmarking succesful, assert success/True.
|
||||
assert True
|
||||
except Exception as e:
|
||||
# If anything happen during benchmarking, assert False/failure.
|
||||
assert False
|
||||
3
cpp/.gitignore
vendored
Normal file
3
cpp/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
*.mlir
|
||||
*.vmfb
|
||||
*.ini
|
||||
52
cpp/CMakeLists.txt
Normal file
52
cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
cmake_minimum_required(VERSION 3.21...3.23)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Project configuration
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
project(iree-samples C CXX)
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Core project dependency
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
message(STATUS "Fetching core IREE repo (this may take a few minutes)...")
|
||||
# Note: for log output, set -DFETCHCONTENT_QUIET=OFF,
|
||||
# see https://gitlab.kitware.com/cmake/cmake/-/issues/18238#note_440475
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
iree
|
||||
GIT_REPOSITORY https://github.com/nod-ai/srt.git
|
||||
GIT_TAG shark
|
||||
GIT_SUBMODULES_RECURSE OFF
|
||||
GIT_SHALLOW OFF
|
||||
GIT_PROGRESS ON
|
||||
USES_TERMINAL_DOWNLOAD ON
|
||||
)
|
||||
|
||||
# Extend module path to find MLIR CMake modules.
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
# Disable core project features not needed for these out of tree samples.
|
||||
set(IREE_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(IREE_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
|
||||
|
||||
FetchContent_MakeAvailable(iree)
|
||||
FetchContent_GetProperties(iree SOURCE_DIR IREE_SOURCE_DIR)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Individual samples
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
add_subdirectory(vulkan_gui)
|
||||
82
cpp/README.md
Normal file
82
cpp/README.md
Normal file
@@ -0,0 +1,82 @@
|
||||
# SHARK C/C++ Samples
|
||||
|
||||
These C/C++ samples can be built using CMake. The samples depend on the main
|
||||
SHARK-Runtime project's C/C++ sources, including both the runtime and the compiler.
|
||||
|
||||
Individual samples may require additional dependencies. Watch CMake's output
|
||||
for information about which you are missing for individual samples.
|
||||
|
||||
On Windows we recommend using https://github.com/microsoft/vcpkg to download packages for
|
||||
your system. The general setup flow looks like
|
||||
|
||||
*Install and activate SHARK*
|
||||
|
||||
```bash
|
||||
source shark.venv/bin/activate #follow main repo instructions to setup your venv
|
||||
```
|
||||
|
||||
*Install Dependencies*
|
||||
|
||||
```bash
|
||||
vcpkg install [library] --triplet [your platform]
|
||||
vcpkg integrate install
|
||||
|
||||
# Then pass `-DCMAKE_TOOLCHAIN_FILE=[check logs for path]` when configuring CMake
|
||||
```
|
||||
|
||||
In Ubuntu Linux you can install
|
||||
|
||||
```bash
|
||||
sudo apt install libsdl2-dev
|
||||
```
|
||||
|
||||
*Build*
|
||||
```bash
|
||||
cd cpp
|
||||
cmake -GNinja -B build/
|
||||
cmake --build build/
|
||||
```
|
||||
|
||||
*Prepare the model*
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
|
||||
```
|
||||
*Prepare the input*
|
||||
|
||||
```bash
|
||||
python save_img.py
|
||||
```
|
||||
Note that this requires tensorflow, e.g.
|
||||
```bash
|
||||
python -m pip install tensorflow
|
||||
```
|
||||
|
||||
*Run the vulkan_gui*
|
||||
```bash
|
||||
./build/vulkan_gui/iree-samples-resnet-vulkan-gui
|
||||
```
|
||||
|
||||
## Other models
|
||||
A tool for benchmarking other models is built and can be invoked with a command like the following
|
||||
```bash
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=path/to/.vmfb --function_input=...
|
||||
```
|
||||
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
|
||||
```bash
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
|
||||
```
|
||||
VAE and Autoencoder are also available
|
||||
```bash
|
||||
# VAE
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
|
||||
|
||||
# CLIP Autoencoder
|
||||
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
|
||||
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
|
||||
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
|
||||
```
|
||||
BIN
cpp/dog_imagenet.jpg
Normal file
BIN
cpp/dog_imagenet.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
18
cpp/save_img.py
Normal file
18
cpp/save_img.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
def load_and_preprocess_image(fname: str):
|
||||
image = tf.io.read_file(fname)
|
||||
image = tf.image.decode_image(image, channels=3)
|
||||
image = tf.image.resize(image, (224, 224))
|
||||
image = image[tf.newaxis, :]
|
||||
# preprocessing pipeline
|
||||
input_tensor = tf.keras.applications.resnet50.preprocess_input(image)
|
||||
return input_tensor
|
||||
|
||||
|
||||
data = load_and_preprocess_image("dog_imagenet.jpg").numpy()
|
||||
|
||||
data.tofile("dog.bin")
|
||||
84
cpp/vision_inference/CMakeLists.txt
Normal file
84
cpp/vision_inference/CMakeLists.txt
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
if(NOT IREE_TARGET_BACKEND_LLVM_CPU OR
|
||||
NOT IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
|
||||
message(STATUS "Missing LLVM backend and/or embeddded elf loader, skipping vision_inference sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# vcpkg install stb
|
||||
# tested with version 2021-09-10
|
||||
find_package(Stb)
|
||||
if(NOT Stb_FOUND)
|
||||
message(STATUS "Could not find Stb, skipping vision inference sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Compile mnist.mlir to mnist.vmfb.
|
||||
set(_COMPILE_TOOL_EXECUTABLE $<TARGET_FILE:iree-compile>)
|
||||
set(_COMPILE_ARGS)
|
||||
list(APPEND _COMPILE_ARGS "--iree-input-type=auto")
|
||||
list(APPEND _COMPILE_ARGS "--iree-hal-target-backends=llvm-cpu")
|
||||
list(APPEND _COMPILE_ARGS "${IREE_SOURCE_DIR}/samples/models/mnist.mlir")
|
||||
list(APPEND _COMPILE_ARGS "-o")
|
||||
list(APPEND _COMPILE_ARGS "mnist.vmfb")
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
|
||||
COMMAND ${_COMPILE_TOOL_EXECUTABLE} ${_COMPILE_ARGS}
|
||||
DEPENDS ${_COMPILE_TOOL_EXECUTABLE} "${IREE_SOURCE_DIR}/samples/models/mnist.mlir"
|
||||
)
|
||||
# Embed mnist.vmfb into a C file as mnist_bytecode_module_c.[h/c]
|
||||
set(_EMBED_DATA_EXECUTABLE $<TARGET_FILE:generate_embed_data>)
|
||||
set(_EMBED_ARGS)
|
||||
list(APPEND _EMBED_ARGS "--output_header=mnist_bytecode_module_c.h")
|
||||
list(APPEND _EMBED_ARGS "--output_impl=mnist_bytecode_module_c.c")
|
||||
list(APPEND _EMBED_ARGS "--identifier=iree_samples_vision_inference_mnist_bytecode_module")
|
||||
list(APPEND _EMBED_ARGS "--flatten")
|
||||
list(APPEND _EMBED_ARGS "${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb")
|
||||
add_custom_command(
|
||||
OUTPUT "mnist_bytecode_module_c.h" "mnist_bytecode_module_c.c"
|
||||
COMMAND ${_EMBED_DATA_EXECUTABLE} ${_EMBED_ARGS}
|
||||
DEPENDS ${_EMBED_DATA_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/mnist.vmfb
|
||||
)
|
||||
# Define a library target for mnist_bytecode_module_c.
|
||||
add_library(iree_samples_vision_inference_mnist_bytecode_module_c OBJECT)
|
||||
target_sources(iree_samples_vision_inference_mnist_bytecode_module_c
|
||||
PRIVATE
|
||||
mnist_bytecode_module_c.h
|
||||
mnist_bytecode_module_c.c
|
||||
)
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "iree-run-mnist-module")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
"image_util.h"
|
||||
"image_util.c"
|
||||
"iree-run-mnist-module.c"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "iree-run-mnist-module")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_include_directories(${_NAME} PRIVATE
|
||||
${Stb_INCLUDE_DIR}
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
iree_base_base
|
||||
iree_base_tracing
|
||||
iree_hal_hal
|
||||
iree_runtime_runtime
|
||||
iree_samples_vision_inference_mnist_bytecode_module_c
|
||||
)
|
||||
|
||||
# Define a target that copies the test image into the build directory.
|
||||
add_custom_target(iree_samples_vision_inference_test_image
|
||||
COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/mnist_test.png" "${CMAKE_CURRENT_BINARY_DIR}/mnist_test.png")
|
||||
add_dependencies(${_NAME} iree_samples_vision_inference_test_image)
|
||||
|
||||
message(STATUS "Configured vision_inference sample successfully")
|
||||
8
cpp/vision_inference/README.md
Normal file
8
cpp/vision_inference/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Vision Inference Sample (C code)
|
||||
|
||||
This sample demonstrates how to run a MNIST handwritten digit detection vision
|
||||
model on an image using IREE's C API.
|
||||
|
||||
A similar sample is implemented using a Python script and IREE's command line
|
||||
tools over in the primary iree repository at
|
||||
https://github.com/iree-org/iree/tree/main/samples/vision_inference
|
||||
224
cpp/vision_inference/image_util.c
Normal file
224
cpp/vision_inference/image_util.c
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#include "image_util.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/tracing.h"
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
const uint8_t* pixel_data, iree_host_size_t buffer_length,
|
||||
const float* input_range, iree_host_size_t range_length,
|
||||
float* out_buffer) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
if (range_length != 2) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"range defined as 2-element [min, max] array.");
|
||||
}
|
||||
float input_scale = fabsf(input_range[1] - input_range[0]) / 2.0f;
|
||||
float input_offset = (input_range[0] + input_range[1]) / 2.0f;
|
||||
const float kUint8Mean = 127.5f;
|
||||
for (int i = 0; i < buffer_length; ++i) {
|
||||
out_buffer[i] =
|
||||
(((float)(pixel_data[i])) - kUint8Mean) / kUint8Mean * input_scale +
|
||||
input_offset;
|
||||
}
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_load_pixel_data_impl(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
|
||||
int img_dims[3];
|
||||
if (stbi_info(filename.data, img_dims, &(img_dims[1]), &(img_dims[2])) == 0) {
|
||||
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
|
||||
(int)filename.size, filename.data);
|
||||
}
|
||||
if (!(element_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 ||
|
||||
element_type == IREE_HAL_ELEMENT_TYPE_SINT_8 ||
|
||||
element_type == IREE_HAL_ELEMENT_TYPE_UINT_8)) {
|
||||
char element_type_str[16];
|
||||
IREE_RETURN_IF_ERROR(iree_hal_format_element_type(
|
||||
element_type, sizeof(element_type_str), element_type_str, NULL));
|
||||
return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
|
||||
"element type %s not supported", element_type_str);
|
||||
}
|
||||
switch (shape_rank) {
|
||||
case 2: { // Assume tensor <height x width>
|
||||
if (img_dims[2] != 1 || (shape[0] != img_dims[1]) ||
|
||||
(shape[1] != img_dims[0])) {
|
||||
return iree_make_status(
|
||||
IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim "x%" PRIdim, img_dims[0],
|
||||
img_dims[1], img_dims[2], shape[1], shape[0]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 3: { // Assume tensor <height x width x channel>
|
||||
if (shape[0] != img_dims[1] || shape[1] != img_dims[0] ||
|
||||
shape[2] != img_dims[2]) {
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim
|
||||
"x%" PRIdim "x%" PRIdim,
|
||||
img_dims[0], img_dims[1], img_dims[2], shape[1],
|
||||
shape[0], shape[2]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 4: { // Assume tensor <batch x height x width x channel>
|
||||
if (shape[1] != img_dims[1] || shape[2] != img_dims[0] ||
|
||||
shape[3] != img_dims[2]) {
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"image size: %dx%dx%d, expected: %" PRIdim
|
||||
"x%" PRIdim "x%" PRIdim,
|
||||
img_dims[0], img_dims[1], img_dims[2], shape[2],
|
||||
shape[1], shape[3]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return iree_make_status(
|
||||
IREE_STATUS_INVALID_ARGUMENT,
|
||||
"Input buffer shape rank %" PRIhsz " not supported", shape_rank);
|
||||
}
|
||||
// Drop the alpha channel if present.
|
||||
int req_ch = (img_dims[2] >= 3) ? 3 : 0;
|
||||
*out_pixel_data = stbi_load(filename.data, img_dims, &(img_dims[1]),
|
||||
&(img_dims[2]), req_ch);
|
||||
if (*out_pixel_data == NULL) {
|
||||
return iree_make_status(IREE_STATUS_NOT_FOUND, "can't load image %.*s",
|
||||
(int)filename.size, filename.data);
|
||||
}
|
||||
*out_buffer_length =
|
||||
img_dims[0] * img_dims[1] * (img_dims[2] > 3 ? 3 : img_dims[2]);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_load_pixel_data(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
iree_status_t result = iree_tools_utils_load_pixel_data_impl(
|
||||
filename, shape, shape_rank, element_type, out_pixel_data,
|
||||
out_buffer_length);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return result;
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
*out_buffer_view = NULL;
|
||||
if (element_type != IREE_HAL_ELEMENT_TYPE_SINT_8 &&
|
||||
element_type != IREE_HAL_ELEMENT_TYPE_UINT_8) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"element type should be i8 or u8");
|
||||
}
|
||||
|
||||
iree_status_t result;
|
||||
uint8_t* pixel_data = NULL;
|
||||
iree_host_size_t buffer_length;
|
||||
result = iree_tools_utils_load_pixel_data(
|
||||
filename, shape, shape_rank, element_type, &pixel_data, &buffer_length);
|
||||
if (iree_status_is_ok(result)) {
|
||||
iree_host_size_t element_byte =
|
||||
iree_hal_element_dense_byte_count(element_type);
|
||||
// SINT_8 and UINT_8 perform direct buffer wrap.
|
||||
result = iree_hal_buffer_view_allocate_buffer(
|
||||
allocator, shape_rank, shape, element_type,
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR,
|
||||
(iree_hal_buffer_params_t){
|
||||
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL,
|
||||
.access = IREE_HAL_MEMORY_ACCESS_READ,
|
||||
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
|
||||
IREE_HAL_BUFFER_USAGE_TRANSFER,
|
||||
},
|
||||
iree_make_const_byte_span(pixel_data, element_byte * buffer_length),
|
||||
out_buffer_view);
|
||||
}
|
||||
stbi_image_free(pixel_data);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return result;
|
||||
}
|
||||
|
||||
typedef struct iree_tools_utils_buffer_view_load_params_t {
|
||||
const uint8_t* pixel_data;
|
||||
iree_host_size_t pixel_data_length;
|
||||
const float* input_range;
|
||||
iree_host_size_t input_range_length;
|
||||
} iree_tools_utils_buffer_view_load_params_t;
|
||||
static iree_status_t iree_tools_utils_buffer_view_load_image_rescaled(
|
||||
iree_hal_buffer_mapping_t* mapping, void* user_data) {
|
||||
iree_tools_utils_buffer_view_load_params_t* params =
|
||||
(iree_tools_utils_buffer_view_load_params_t*)user_data;
|
||||
return iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
params->pixel_data, params->pixel_data_length, params->input_range,
|
||||
params->input_range_length, (float*)mapping->contents.data);
|
||||
}
|
||||
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, const float* input_range,
|
||||
iree_host_size_t input_range_length,
|
||||
iree_hal_buffer_view_t** out_buffer_view) {
|
||||
IREE_TRACE_ZONE_BEGIN(z0);
|
||||
*out_buffer_view = NULL;
|
||||
if (element_type != IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
|
||||
"element type should be f32");
|
||||
}
|
||||
|
||||
// Classic row-major image layout.
|
||||
iree_hal_encoding_type_t encoding_type =
|
||||
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;
|
||||
|
||||
// Load pixel data from the file into a new host memory allocation (the only
|
||||
// interface stb_image provides). A real application would want to use the
|
||||
// generation callback to directly decode the image into the target mapped
|
||||
// device buffer.
|
||||
uint8_t* pixel_data = NULL;
|
||||
iree_host_size_t buffer_length = 0;
|
||||
IREE_RETURN_AND_END_ZONE_IF_ERROR(
|
||||
z0, iree_tools_utils_load_pixel_data(filename, shape, shape_rank,
|
||||
element_type, &pixel_data,
|
||||
&buffer_length));
|
||||
|
||||
iree_tools_utils_buffer_view_load_params_t params = {
|
||||
.pixel_data = pixel_data,
|
||||
.pixel_data_length = buffer_length,
|
||||
.input_range = input_range,
|
||||
.input_range_length = input_range_length,
|
||||
};
|
||||
iree_status_t status = iree_hal_buffer_view_generate_buffer(
|
||||
allocator, shape_rank, shape, element_type, encoding_type,
|
||||
(iree_hal_buffer_params_t){
|
||||
.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
|
||||
IREE_HAL_MEMORY_TYPE_HOST_VISIBLE,
|
||||
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
|
||||
IREE_HAL_BUFFER_USAGE_TRANSFER |
|
||||
IREE_HAL_BUFFER_USAGE_MAPPING,
|
||||
},
|
||||
iree_tools_utils_buffer_view_load_image_rescaled, ¶ms,
|
||||
out_buffer_view);
|
||||
|
||||
stbi_image_free(pixel_data);
|
||||
IREE_TRACE_ZONE_END(z0);
|
||||
return status;
|
||||
}
|
||||
77
cpp/vision_inference/image_util.h
Normal file
77
cpp/vision_inference/image_util.h
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#ifndef IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
#define IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
|
||||
#include "iree/base/api.h"
|
||||
#include "iree/hal/api.h"
|
||||
#include "iree/hal/buffer_view.h"
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// Loads the image at |filename| into |out_pixel_data| and sets
|
||||
// |out_buffer_length| to its length.
|
||||
//
|
||||
// The image dimension must match the width, height, and channel in|shape|,
|
||||
// while 2 <= |shape_rank| <= 4 to match the image tensor format.
|
||||
//
|
||||
// The file must be in a format supported by stb_image.h.
|
||||
// The returned |out_pixel_data| buffer must be released by the caller.
|
||||
iree_status_t iree_tools_utils_load_pixel_data(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
uint8_t** out_pixel_data, iree_host_size_t* out_buffer_length);
|
||||
|
||||
// Parse the content in an image file in |filename| into a HAL buffer view
|
||||
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
|
||||
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
|
||||
//
|
||||
// The |element_type| has to be SINT_8 or UINT_8. For FLOAT_32, use
|
||||
// |iree_tools_utils_buffer_view_from_image_rescaled| instead.
|
||||
//
|
||||
// The returned |out_buffer_view| must be released by the caller.
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, iree_hal_buffer_view_t** out_buffer_view);
|
||||
|
||||
// Parse the content in an image file in |filename| into a HAL buffer view
|
||||
// |out_buffer_view|. |out_buffer_view| properties are defined by |shape|,
|
||||
// |shape_rank|, and |element_type|, while being allocated by |allocator|.
|
||||
// The value in |out_buffer_view| is rescaled with |input_range|.
|
||||
//
|
||||
// The |element_type| has to be FLOAT_32, For SINT_8 or UINT_8, use
|
||||
// |iree_tools_utils_buffer_view_from_image| instead.
|
||||
//
|
||||
// The returned |out_buffer_view| must be released by the caller.
|
||||
iree_status_t iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
const iree_string_view_t filename, const iree_hal_dim_t* shape,
|
||||
iree_host_size_t shape_rank, iree_hal_element_type_t element_type,
|
||||
iree_hal_allocator_t* allocator, const float* input_range,
|
||||
iree_host_size_t input_range_length,
|
||||
iree_hal_buffer_view_t** out_buffer_view);
|
||||
|
||||
// Normalize uint8_t |pixel_data| of the size |buffer_length| to float buffer
|
||||
// |out_buffer| with the range |input_range|.
|
||||
//
|
||||
// float32_x = (uint8_x - 127.5) / 127.5 * input_scale + input_offset, where
|
||||
// input_scale = abs(|input_range[0]| - |input_range[1]| / 2
|
||||
// input_offset = |input_range[0]| + |input_range[1]| / 2
|
||||
//
|
||||
// |out_buffer| needs to be allocated before the call.
|
||||
iree_status_t iree_tools_utils_pixel_rescaled_to_buffer(
|
||||
const uint8_t* pixel_data, iree_host_size_t pixel_count,
|
||||
const float* input_range, iree_host_size_t input_range_length,
|
||||
float* out_buffer);
|
||||
|
||||
#if __cplusplus
|
||||
}
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // IREE_SAMPLES_VISION_INFERENCE_IMAGE_UTIL_H_
|
||||
121
cpp/vision_inference/iree-run-mnist-module.c
Normal file
121
cpp/vision_inference/iree-run-mnist-module.c
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright 2021 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
// This sample uses image_util to load a hand-written image as an
|
||||
// iree_hal_buffer_view_t then passes it to the bytecode module built from
|
||||
// mnist.mlir on the CPU backend with the local-task driver.
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#include "image_util.h"
|
||||
#include "iree/runtime/api.h"
|
||||
#include "mnist_bytecode_module_c.h"
|
||||
|
||||
iree_status_t Run(const iree_string_view_t image_path) {
|
||||
iree_runtime_instance_options_t instance_options;
|
||||
iree_runtime_instance_options_initialize(IREE_API_VERSION_LATEST,
|
||||
&instance_options);
|
||||
iree_runtime_instance_options_use_all_available_drivers(&instance_options);
|
||||
iree_runtime_instance_t* instance = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_instance_create(
|
||||
&instance_options, iree_allocator_system(), &instance));
|
||||
|
||||
// TODO(#5724): move device selection into the compiled modules.
|
||||
iree_hal_device_t* device = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_instance_try_create_default_device(
|
||||
instance, iree_make_cstring_view("local-task"), &device));
|
||||
|
||||
// Create one session per loaded module to hold the module state.
|
||||
iree_runtime_session_options_t session_options;
|
||||
iree_runtime_session_options_initialize(&session_options);
|
||||
iree_runtime_session_t* session = NULL;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_session_create_with_device(
|
||||
instance, &session_options, device,
|
||||
iree_runtime_instance_host_allocator(instance), &session));
|
||||
iree_hal_device_release(device);
|
||||
|
||||
const struct iree_file_toc_t* module_file =
|
||||
iree_samples_vision_inference_mnist_bytecode_module_create();
|
||||
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_session_append_bytecode_module_from_memory(
|
||||
session, iree_make_const_byte_span(module_file->data, module_file->size),
|
||||
iree_allocator_null()));
|
||||
|
||||
iree_runtime_call_t call;
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(
|
||||
session, iree_make_cstring_view("module.predict"), &call));
|
||||
|
||||
// Prepare the input hal buffer view with image_util library.
|
||||
// The input of the mmist model is single 28x28 pixel image as a
|
||||
// tensor<1x28x28x1xf32>, with pixels in [0.0, 1.0].
|
||||
iree_hal_buffer_view_t* buffer_view = NULL;
|
||||
iree_hal_dim_t buffer_shape[] = {1, 28, 28, 1};
|
||||
iree_hal_element_type_t hal_element_type = IREE_HAL_ELEMENT_TYPE_FLOAT_32;
|
||||
float input_range[2] = {0.0f, 1.0f};
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_tools_utils_buffer_view_from_image_rescaled(
|
||||
image_path, buffer_shape, IREE_ARRAYSIZE(buffer_shape),
|
||||
hal_element_type, iree_hal_device_allocator(device), input_range,
|
||||
IREE_ARRAYSIZE(input_range), &buffer_view),
|
||||
"load image");
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_runtime_call_inputs_push_back_buffer_view(&call, buffer_view));
|
||||
iree_hal_buffer_view_release(buffer_view);
|
||||
|
||||
IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0));
|
||||
|
||||
// Get the result buffers from the invocation.
|
||||
iree_hal_buffer_view_t* ret_buffer_view = NULL;
|
||||
IREE_RETURN_IF_ERROR(
|
||||
iree_runtime_call_outputs_pop_front_buffer_view(&call, &ret_buffer_view));
|
||||
|
||||
// Read back the results. The output of the mnist model is a 1x10 prediction
|
||||
// confidence values for each digit in [0, 9].
|
||||
float predictions[1 * 10] = {0.0f};
|
||||
IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h(
|
||||
iree_runtime_session_device(session),
|
||||
iree_hal_buffer_view_buffer(ret_buffer_view), 0, predictions,
|
||||
sizeof(predictions), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
|
||||
iree_infinite_timeout()));
|
||||
iree_hal_buffer_view_release(ret_buffer_view);
|
||||
|
||||
// Get the highest index from the output.
|
||||
float result_val = FLT_MIN;
|
||||
int result_idx = 0;
|
||||
for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(predictions); ++i) {
|
||||
if (predictions[i] > result_val) {
|
||||
result_val = predictions[i];
|
||||
result_idx = i;
|
||||
}
|
||||
}
|
||||
fprintf(stdout, "Detected number: %d\n", result_idx);
|
||||
|
||||
iree_runtime_call_deinitialize(&call);
|
||||
iree_runtime_session_release(session);
|
||||
iree_runtime_instance_release(instance);
|
||||
return iree_ok_status();
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc > 2) {
|
||||
fprintf(stderr, "Usage: iree-run-mnist-module <image file>\n");
|
||||
return -1;
|
||||
}
|
||||
iree_string_view_t image_path;
|
||||
if (argc == 1) {
|
||||
image_path = iree_make_cstring_view("mnist_test.png");
|
||||
} else {
|
||||
image_path = iree_make_cstring_view(argv[1]);
|
||||
}
|
||||
iree_status_t result = Run(image_path);
|
||||
if (!iree_status_is_ok(result)) {
|
||||
iree_status_fprint(stderr, result);
|
||||
iree_status_ignore(result);
|
||||
return -1;
|
||||
}
|
||||
iree_status_ignore(result);
|
||||
return 0;
|
||||
}
|
||||
BIN
cpp/vision_inference/mnist_test.png
Normal file
BIN
cpp/vision_inference/mnist_test.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 261 B |
116
cpp/vulkan_gui/CMakeLists.txt
Normal file
116
cpp/vulkan_gui/CMakeLists.txt
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright 2022 The IREE Authors
|
||||
#
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
if(NOT IREE_TARGET_BACKEND_VULKAN_SPIRV OR
|
||||
NOT IREE_HAL_DRIVER_VULKAN)
|
||||
message(STATUS "Missing Vulkan backend and/or driver, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# This target statically links against Vulkan.
|
||||
# One way to achieve this is by installing the Vulkan SDK from
|
||||
# https://vulkan.lunarg.com/.
|
||||
include(FindVulkan)
|
||||
if(NOT Vulkan_FOUND)
|
||||
message(STATUS "Could not find Vulkan, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# vcpkg install sdl2[vulkan]
|
||||
# tested with versions 2.0.14#4 - 2.0.22#1
|
||||
find_package(SDL2)
|
||||
if(NOT SDL2_FOUND)
|
||||
message(STATUS "Could not find SDL2, skipping vulkan_gui sample")
|
||||
return()
|
||||
endif()
|
||||
|
||||
FetchContent_Declare(
|
||||
imgui
|
||||
GIT_REPOSITORY https://github.com/ocornut/imgui
|
||||
GIT_TAG master
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(imgui)
|
||||
|
||||
# Dear ImGui
|
||||
set(IMGUI_DIR ${CMAKE_BINARY_DIR}/_deps/imgui-src)
|
||||
message("Looking for Imgui in ${IMGUI_DIR}")
|
||||
include_directories(${IMGUI_DIR} ${IMGUI_DIR}/backends ..)
|
||||
|
||||
|
||||
function(iree_vulkan_sample)
|
||||
|
||||
cmake_parse_arguments(
|
||||
_RULE
|
||||
""
|
||||
"NAME"
|
||||
"SRCS"
|
||||
${ARGN}
|
||||
)
|
||||
|
||||
|
||||
# Define the sample executable.
|
||||
set(_NAME "${_RULE_NAME}")
|
||||
set(SRCS "${_RULE_SRCS}")
|
||||
add_executable(${_NAME} "")
|
||||
target_sources(${_NAME}
|
||||
PRIVATE
|
||||
${SRCS}
|
||||
"${IMGUI_DIR}/backends/imgui_impl_sdl.cpp"
|
||||
"${IMGUI_DIR}/backends/imgui_impl_vulkan.cpp"
|
||||
"${IMGUI_DIR}/imgui.cpp"
|
||||
"${IMGUI_DIR}/imgui_draw.cpp"
|
||||
"${IMGUI_DIR}/imgui_demo.cpp"
|
||||
"${IMGUI_DIR}/imgui_tables.cpp"
|
||||
"${IMGUI_DIR}/imgui_widgets.cpp"
|
||||
)
|
||||
set_target_properties(${_NAME} PROPERTIES OUTPUT_NAME "${_NAME}")
|
||||
target_include_directories(${_NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}>
|
||||
)
|
||||
target_link_libraries(${_NAME}
|
||||
SDL2::SDL2
|
||||
Vulkan::Vulkan
|
||||
iree_runtime_runtime
|
||||
iree_base_internal_main
|
||||
iree_hal_drivers_vulkan_registration_registration
|
||||
iree_modules_hal_hal
|
||||
iree_vm_vm
|
||||
iree_vm_bytecode_module
|
||||
iree_vm_cc
|
||||
iree_tooling_vm_util_cc
|
||||
iree_tooling_context_util
|
||||
)
|
||||
|
||||
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
|
||||
set(_GUI_LINKOPTS "-SUBSYSTEM:CONSOLE")
|
||||
else()
|
||||
set(_GUI_LINKOPTS "")
|
||||
endif()
|
||||
|
||||
target_link_options(${_NAME}
|
||||
PRIVATE
|
||||
${_GUI_LINKOPTS}
|
||||
)
|
||||
endfunction()
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-samples-resnet-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_resnet_inference_gui.cc
|
||||
)
|
||||
|
||||
iree_vulkan_sample(
|
||||
NAME
|
||||
iree-vulkan-gui
|
||||
|
||||
SRCS
|
||||
vulkan_inference_gui.cc
|
||||
)
|
||||
|
||||
message(STATUS "Configured vulkan_gui sample successfully")
|
||||
4
cpp/vulkan_gui/simple_mul.mlir
Normal file
4
cpp/vulkan_gui/simple_mul.mlir
Normal file
@@ -0,0 +1,4 @@
|
||||
func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%0 = "arith.mulf"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
BIN
cpp/vulkan_gui/snail_imagenet.jpg
Normal file
BIN
cpp/vulkan_gui/snail_imagenet.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
7897
cpp/vulkan_gui/stb_image.h
Normal file
7897
cpp/vulkan_gui/stb_image.h
Normal file
File diff suppressed because it is too large
Load Diff
957
cpp/vulkan_gui/vulkan_inference_gui.cc
Normal file
957
cpp/vulkan_gui/vulkan_inference_gui.cc
Normal file
@@ -0,0 +1,957 @@
|
||||
// Copyright 2019 The IREE Authors
|
||||
//
|
||||
// Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
// Vulkan Graphics + IREE API Integration Sample.
|
||||
|
||||
#include <SDL.h>
|
||||
#include <SDL_vulkan.h>
|
||||
#include <imgui.h>
|
||||
#include <imgui_impl_sdl.h>
|
||||
#include <imgui_impl_vulkan.h>
|
||||
#include <vulkan/vulkan.h>
|
||||
|
||||
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <array>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "iree/hal/drivers/vulkan/api.h"
|
||||
|
||||
// IREE's C API:
|
||||
#include "iree/base/api.h"
|
||||
#include "iree/hal/api.h"
|
||||
#include "iree/hal/drivers/vulkan/registration/driver_module.h"
|
||||
#include "iree/modules/hal/module.h"
|
||||
#include "iree/vm/api.h"
|
||||
#include "iree/vm/bytecode_module.h"
|
||||
#include "iree/vm/ref_cc.h"
|
||||
|
||||
// iree-run-module
|
||||
#include "iree/base/internal/flags.h"
|
||||
#include "iree/base/status_cc.h"
|
||||
#include "iree/base/tracing.h"
|
||||
#include "iree/modules/hal/types.h"
|
||||
#include "iree/tooling/comparison.h"
|
||||
#include "iree/tooling/context_util.h"
|
||||
#include "iree/tooling/vm_util_cc.h"
|
||||
|
||||
// Other dependencies (helpers, etc.)
|
||||
#include "iree/base/internal/main.h"
|
||||
|
||||
#define IMGUI_UNLIMITED_FRAME_RATE
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
IREE_FLAG(string, entry_function, "",
|
||||
"Name of a function contained in the module specified by module_file "
|
||||
"to run.");
|
||||
|
||||
// TODO(benvanik): move --function_input= flag into a util.
|
||||
static iree_status_t parse_function_io(iree_string_view_t flag_name,
|
||||
void* storage,
|
||||
iree_string_view_t value) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
list->push_back(std::string(value.data, value.size));
|
||||
return iree_ok_status();
|
||||
}
|
||||
static void print_function_io(iree_string_view_t flag_name, void* storage,
|
||||
FILE* file) {
|
||||
auto* list = (std::vector<std::string>*)storage;
|
||||
if (list->empty()) {
|
||||
fprintf(file, "# --%.*s=\n", (int)flag_name.size, flag_name.data);
|
||||
} else {
|
||||
for (size_t i = 0; i < list->size(); ++i) {
|
||||
fprintf(file, "--%.*s=\"%s\"\n", (int)flag_name.size, flag_name.data,
|
||||
list->at(i).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
static std::vector<std::string> FLAG_function_inputs;
|
||||
IREE_FLAG_CALLBACK(
|
||||
parse_function_io, print_function_io, &FLAG_function_inputs, function_input,
|
||||
"An input (a) value or (b) buffer of the format:\n"
|
||||
" (a) scalar value\n"
|
||||
" value\n"
|
||||
" e.g.: --function_input=\"3.14\"\n"
|
||||
" (b) buffer:\n"
|
||||
" [shape]xtype=[value]\n"
|
||||
" e.g.: --function_input=\"2x2xi32=1 2 3 4\"\n"
|
||||
"Optionally, brackets may be used to separate the element values:\n"
|
||||
" 2x2xi32=[[1 2][3 4]]\n"
|
||||
"Raw binary files can be read to provide buffer contents:\n"
|
||||
" 2x2xi32=@some/file.bin\n"
|
||||
"numpy npy files (from numpy.save) can be read to provide 1+ values:\n"
|
||||
" @some.npy\n"
|
||||
"Each occurrence of the flag indicates an input in the order they were\n"
|
||||
"specified on the command line.");
|
||||
|
||||
typedef struct iree_file_toc_t {
|
||||
const char* name; // the file's original name
|
||||
char* data; // beginning of the file
|
||||
size_t size; // length of the file
|
||||
} iree_file_toc_t;
|
||||
|
||||
bool load_file(const char* filename, char** pOut, size_t* pSize)
|
||||
{
|
||||
FILE* f = fopen(filename, "rb");
|
||||
if (f == NULL)
|
||||
{
|
||||
fprintf(stderr, "Can't open %s\n", filename);
|
||||
return false;
|
||||
}
|
||||
|
||||
fseek(f, 0L, SEEK_END);
|
||||
*pSize = ftell(f);
|
||||
fseek(f, 0L, SEEK_SET);
|
||||
|
||||
*pOut = (char*)malloc(*pSize);
|
||||
|
||||
size_t size = fread(*pOut, *pSize, 1, f);
|
||||
|
||||
fclose(f);
|
||||
|
||||
return size != 0;
|
||||
}
|
||||
|
||||
static VkAllocationCallbacks* g_Allocator = NULL;
|
||||
static VkInstance g_Instance = VK_NULL_HANDLE;
|
||||
static VkPhysicalDevice g_PhysicalDevice = VK_NULL_HANDLE;
|
||||
static VkDevice g_Device = VK_NULL_HANDLE;
|
||||
static uint32_t g_QueueFamily = (uint32_t)-1;
|
||||
static VkQueue g_Queue = VK_NULL_HANDLE;
|
||||
static VkPipelineCache g_PipelineCache = VK_NULL_HANDLE;
|
||||
static VkDescriptorPool g_DescriptorPool = VK_NULL_HANDLE;
|
||||
|
||||
static ImGui_ImplVulkanH_Window g_MainWindowData;
|
||||
static uint32_t g_MinImageCount = 2;
|
||||
static bool g_SwapChainRebuild = false;
|
||||
static int g_SwapChainResizeWidth = 0;
|
||||
static int g_SwapChainResizeHeight = 0;
|
||||
|
||||
static void check_vk_result(VkResult err) {
|
||||
if (err == 0) return;
|
||||
fprintf(stderr, "VkResult: %d\n", err);
|
||||
abort();
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan layers used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeLayers(
|
||||
iree_hal_vulkan_extensibility_set_t extensibility_set,
|
||||
iree_hal_vulkan_features_t features) {
|
||||
iree_host_size_t required_count;
|
||||
iree_hal_vulkan_query_extensibility_set(
|
||||
features, extensibility_set, /*string_capacity=*/0, &required_count,
|
||||
/*out_string_values=*/NULL);
|
||||
std::vector<const char*> layers(required_count);
|
||||
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
|
||||
layers.size(), &required_count,
|
||||
layers.data());
|
||||
return layers;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan extensions used for the given IREE
|
||||
// |extensibility_set| and |features|.
|
||||
std::vector<const char*> GetIreeExtensions(
|
||||
iree_hal_vulkan_extensibility_set_t extensibility_set,
|
||||
iree_hal_vulkan_features_t features) {
|
||||
iree_host_size_t required_count;
|
||||
iree_hal_vulkan_query_extensibility_set(
|
||||
features, extensibility_set, /*string_capacity=*/0, &required_count,
|
||||
/*out_string_values=*/NULL);
|
||||
std::vector<const char*> extensions(required_count);
|
||||
iree_hal_vulkan_query_extensibility_set(features, extensibility_set,
|
||||
extensions.size(), &required_count,
|
||||
extensions.data());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
// Returns the names of the Vulkan extensions used for the given IREE
|
||||
// |vulkan_features|.
|
||||
std::vector<const char*> GetDeviceExtensions(
|
||||
VkPhysicalDevice physical_device,
|
||||
iree_hal_vulkan_features_t vulkan_features) {
|
||||
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED,
|
||||
vulkan_features);
|
||||
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL,
|
||||
vulkan_features);
|
||||
|
||||
uint32_t extension_count = 0;
|
||||
check_vk_result(vkEnumerateDeviceExtensionProperties(
|
||||
physical_device, nullptr, &extension_count, nullptr));
|
||||
std::vector<VkExtensionProperties> extension_properties(extension_count);
|
||||
check_vk_result(vkEnumerateDeviceExtensionProperties(
|
||||
physical_device, nullptr, &extension_count, extension_properties.data()));
|
||||
|
||||
// Merge extensions lists, including optional and required for simplicity.
|
||||
std::set<const char*> ext_set;
|
||||
ext_set.insert("VK_KHR_swapchain");
|
||||
ext_set.insert(iree_required_extensions.begin(),
|
||||
iree_required_extensions.end());
|
||||
for (int i = 0; i < iree_optional_extensions.size(); ++i) {
|
||||
const char* optional_extension = iree_optional_extensions[i];
|
||||
for (int j = 0; j < extension_count; ++j) {
|
||||
if (strcmp(optional_extension, extension_properties[j].extensionName) ==
|
||||
0) {
|
||||
ext_set.insert(optional_extension);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
std::vector<const char*> GetInstanceLayers(
|
||||
iree_hal_vulkan_features_t vulkan_features) {
|
||||
// Query the layers that IREE wants / needs.
|
||||
std::vector<const char*> required_layers = GetIreeLayers(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED, vulkan_features);
|
||||
std::vector<const char*> optional_layers = GetIreeLayers(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, vulkan_features);
|
||||
|
||||
// Query the layers that are available on the Vulkan ICD.
|
||||
uint32_t layer_property_count = 0;
|
||||
check_vk_result(
|
||||
vkEnumerateInstanceLayerProperties(&layer_property_count, NULL));
|
||||
std::vector<VkLayerProperties> layer_properties(layer_property_count);
|
||||
check_vk_result(vkEnumerateInstanceLayerProperties(&layer_property_count,
|
||||
layer_properties.data()));
|
||||
|
||||
// Match between optional/required and available layers.
|
||||
std::vector<const char*> layers;
|
||||
for (const char* layer_name : required_layers) {
|
||||
bool found = false;
|
||||
for (const auto& layer_property : layer_properties) {
|
||||
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
|
||||
found = true;
|
||||
layers.push_back(layer_name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) {
|
||||
fprintf(stderr, "Required layer %s not available\n", layer_name);
|
||||
abort();
|
||||
}
|
||||
}
|
||||
for (const char* layer_name : optional_layers) {
|
||||
for (const auto& layer_property : layer_properties) {
|
||||
if (std::strcmp(layer_name, layer_property.layerName) == 0) {
|
||||
layers.push_back(layer_name);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return layers;
|
||||
}
|
||||
|
||||
std::vector<const char*> GetInstanceExtensions(
|
||||
SDL_Window* window, iree_hal_vulkan_features_t vulkan_features) {
|
||||
// Ask SDL for its list of required instance extensions.
|
||||
uint32_t sdl_extensions_count = 0;
|
||||
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count, NULL);
|
||||
std::vector<const char*> sdl_extensions(sdl_extensions_count);
|
||||
SDL_Vulkan_GetInstanceExtensions(window, &sdl_extensions_count,
|
||||
sdl_extensions.data());
|
||||
|
||||
std::vector<const char*> iree_required_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED,
|
||||
vulkan_features);
|
||||
std::vector<const char*> iree_optional_extensions = GetIreeExtensions(
|
||||
IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL,
|
||||
vulkan_features);
|
||||
|
||||
// Merge extensions lists, including optional and required for simplicity.
|
||||
std::set<const char*> ext_set;
|
||||
ext_set.insert(sdl_extensions.begin(), sdl_extensions.end());
|
||||
ext_set.insert(iree_required_extensions.begin(),
|
||||
iree_required_extensions.end());
|
||||
ext_set.insert(iree_optional_extensions.begin(),
|
||||
iree_optional_extensions.end());
|
||||
std::vector<const char*> extensions(ext_set.begin(), ext_set.end());
|
||||
return extensions;
|
||||
}
|
||||
|
||||
void SetupVulkan(iree_hal_vulkan_features_t vulkan_features,
|
||||
const char** instance_layers, uint32_t instance_layers_count,
|
||||
const char** instance_extensions,
|
||||
uint32_t instance_extensions_count,
|
||||
const VkAllocationCallbacks* allocator, VkInstance* instance,
|
||||
uint32_t* queue_family_index,
|
||||
VkPhysicalDevice* physical_device, VkQueue* queue,
|
||||
VkDevice* device, VkDescriptorPool* descriptor_pool) {
|
||||
VkResult err;
|
||||
|
||||
// Create Vulkan Instance
|
||||
{
|
||||
VkInstanceCreateInfo create_info = {};
|
||||
create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
|
||||
create_info.enabledLayerCount = instance_layers_count;
|
||||
create_info.ppEnabledLayerNames = instance_layers;
|
||||
create_info.enabledExtensionCount = instance_extensions_count;
|
||||
create_info.ppEnabledExtensionNames = instance_extensions;
|
||||
err = vkCreateInstance(&create_info, allocator, instance);
|
||||
check_vk_result(err);
|
||||
}
|
||||
|
||||
// Select GPU
|
||||
{
|
||||
uint32_t gpu_count;
|
||||
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, NULL);
|
||||
check_vk_result(err);
|
||||
IM_ASSERT(gpu_count > 0);
|
||||
|
||||
VkPhysicalDevice* gpus =
|
||||
(VkPhysicalDevice*)malloc(sizeof(VkPhysicalDevice) * gpu_count);
|
||||
err = vkEnumeratePhysicalDevices(*instance, &gpu_count, gpus);
|
||||
check_vk_result(err);
|
||||
|
||||
// Use the first reported GPU for simplicity.
|
||||
*physical_device = gpus[0];
|
||||
|
||||
VkPhysicalDeviceProperties properties;
|
||||
vkGetPhysicalDeviceProperties(*physical_device, &properties);
|
||||
fprintf(stdout, "Selected Vulkan device: '%s'\n", properties.deviceName);
|
||||
free(gpus);
|
||||
}
|
||||
|
||||
// Select queue family. We want a single queue with graphics and compute for
|
||||
// simplicity, but we could also discover and use separate queues for each.
|
||||
{
|
||||
uint32_t count;
|
||||
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, NULL);
|
||||
VkQueueFamilyProperties* queues = (VkQueueFamilyProperties*)malloc(
|
||||
sizeof(VkQueueFamilyProperties) * count);
|
||||
vkGetPhysicalDeviceQueueFamilyProperties(*physical_device, &count, queues);
|
||||
for (uint32_t i = 0; i < count; i++) {
|
||||
if (queues[i].queueFlags &
|
||||
(VK_QUEUE_GRAPHICS_BIT | VK_QUEUE_COMPUTE_BIT)) {
|
||||
*queue_family_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
free(queues);
|
||||
IM_ASSERT(*queue_family_index != (uint32_t)-1);
|
||||
}
|
||||
|
||||
// Create Logical Device (with 1 queue)
|
||||
{
|
||||
std::vector<const char*> device_extensions =
|
||||
GetDeviceExtensions(*physical_device, vulkan_features);
|
||||
const float queue_priority[] = {1.0f};
|
||||
VkDeviceQueueCreateInfo queue_info = {};
|
||||
queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
|
||||
queue_info.queueFamilyIndex = *queue_family_index;
|
||||
queue_info.queueCount = 1;
|
||||
queue_info.pQueuePriorities = queue_priority;
|
||||
VkDeviceCreateInfo create_info = {};
|
||||
create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
|
||||
create_info.queueCreateInfoCount = 1;
|
||||
create_info.pQueueCreateInfos = &queue_info;
|
||||
create_info.enabledExtensionCount =
|
||||
static_cast<uint32_t>(device_extensions.size());
|
||||
create_info.ppEnabledExtensionNames = device_extensions.data();
|
||||
|
||||
// Enable timeline semaphores.
|
||||
VkPhysicalDeviceFeatures2 features2;
|
||||
memset(&features2, 0, sizeof(features2));
|
||||
features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
|
||||
create_info.pNext = &features2;
|
||||
VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features;
|
||||
memset(&semaphore_features, 0, sizeof(semaphore_features));
|
||||
semaphore_features.sType =
|
||||
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES;
|
||||
semaphore_features.pNext = features2.pNext;
|
||||
features2.pNext = &semaphore_features;
|
||||
semaphore_features.timelineSemaphore = VK_TRUE;
|
||||
|
||||
err = vkCreateDevice(*physical_device, &create_info, allocator, device);
|
||||
check_vk_result(err);
|
||||
vkGetDeviceQueue(*device, *queue_family_index, 0, queue);
|
||||
}
|
||||
|
||||
// Create Descriptor Pool
|
||||
{
|
||||
VkDescriptorPoolSize pool_sizes[] = {
|
||||
{VK_DESCRIPTOR_TYPE_SAMPLER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC, 1000},
|
||||
{VK_DESCRIPTOR_TYPE_INPUT_ATTACHMENT, 1000}};
|
||||
VkDescriptorPoolCreateInfo pool_info = {};
|
||||
pool_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
|
||||
pool_info.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
|
||||
pool_info.maxSets = 1000 * IREE_ARRAYSIZE(pool_sizes);
|
||||
pool_info.poolSizeCount = (uint32_t)IREE_ARRAYSIZE(pool_sizes);
|
||||
pool_info.pPoolSizes = pool_sizes;
|
||||
err =
|
||||
vkCreateDescriptorPool(*device, &pool_info, allocator, descriptor_pool);
|
||||
check_vk_result(err);
|
||||
}
|
||||
}
|
||||
|
||||
void SetupVulkanWindow(ImGui_ImplVulkanH_Window* wd,
|
||||
const VkAllocationCallbacks* allocator,
|
||||
VkInstance instance, uint32_t queue_family_index,
|
||||
VkPhysicalDevice physical_device, VkDevice device,
|
||||
VkSurfaceKHR surface, int width, int height,
|
||||
uint32_t min_image_count) {
|
||||
wd->Surface = surface;
|
||||
|
||||
// Check for WSI support
|
||||
VkBool32 res;
|
||||
vkGetPhysicalDeviceSurfaceSupportKHR(physical_device, queue_family_index,
|
||||
wd->Surface, &res);
|
||||
if (res != VK_TRUE) {
|
||||
fprintf(stderr, "Error no WSI support on physical device 0\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Select Surface Format
|
||||
const VkFormat requestSurfaceImageFormat[] = {
|
||||
VK_FORMAT_B8G8R8A8_UNORM, VK_FORMAT_R8G8B8A8_UNORM,
|
||||
VK_FORMAT_B8G8R8_UNORM, VK_FORMAT_R8G8B8_UNORM};
|
||||
const VkColorSpaceKHR requestSurfaceColorSpace =
|
||||
VK_COLORSPACE_SRGB_NONLINEAR_KHR;
|
||||
wd->SurfaceFormat = ImGui_ImplVulkanH_SelectSurfaceFormat(
|
||||
physical_device, wd->Surface, requestSurfaceImageFormat,
|
||||
(size_t)IREE_ARRAYSIZE(requestSurfaceImageFormat),
|
||||
requestSurfaceColorSpace);
|
||||
|
||||
// Select Present Mode
|
||||
#ifdef IMGUI_UNLIMITED_FRAME_RATE
|
||||
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_MAILBOX_KHR,
|
||||
VK_PRESENT_MODE_IMMEDIATE_KHR,
|
||||
VK_PRESENT_MODE_FIFO_KHR};
|
||||
#else
|
||||
VkPresentModeKHR present_modes[] = {VK_PRESENT_MODE_FIFO_KHR};
|
||||
#endif
|
||||
wd->PresentMode = ImGui_ImplVulkanH_SelectPresentMode(
|
||||
physical_device, wd->Surface, &present_modes[0],
|
||||
IREE_ARRAYSIZE(present_modes));
|
||||
|
||||
// Create SwapChain, RenderPass, Framebuffer, etc.
|
||||
IM_ASSERT(min_image_count >= 2);
|
||||
ImGui_ImplVulkanH_CreateOrResizeWindow(instance, physical_device, device, wd,
|
||||
queue_family_index, allocator, width,
|
||||
height, min_image_count);
|
||||
|
||||
// Set clear color.
|
||||
ImVec4 clear_color = ImVec4(0.45f, 0.55f, 0.60f, 1.00f);
|
||||
memcpy(&wd->ClearValue.color.float32[0], &clear_color, 4 * sizeof(float));
|
||||
}
|
||||
|
||||
void RenderFrame(ImGui_ImplVulkanH_Window* wd, VkDevice device, VkQueue queue) {
|
||||
VkResult err;
|
||||
|
||||
VkSemaphore image_acquired_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].ImageAcquiredSemaphore;
|
||||
VkSemaphore render_complete_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
|
||||
err = vkAcquireNextImageKHR(device, wd->Swapchain, UINT64_MAX,
|
||||
image_acquired_semaphore, VK_NULL_HANDLE,
|
||||
&wd->FrameIndex);
|
||||
check_vk_result(err);
|
||||
|
||||
ImGui_ImplVulkanH_Frame* fd = &wd->Frames[wd->FrameIndex];
|
||||
{
|
||||
err = vkWaitForFences(
|
||||
device, 1, &fd->Fence, VK_TRUE,
|
||||
UINT64_MAX); // wait indefinitely instead of periodically checking
|
||||
check_vk_result(err);
|
||||
|
||||
err = vkResetFences(device, 1, &fd->Fence);
|
||||
check_vk_result(err);
|
||||
}
|
||||
{
|
||||
err = vkResetCommandPool(device, fd->CommandPool, 0);
|
||||
check_vk_result(err);
|
||||
VkCommandBufferBeginInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(fd->CommandBuffer, &info);
|
||||
check_vk_result(err);
|
||||
}
|
||||
{
|
||||
VkRenderPassBeginInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_RENDER_PASS_BEGIN_INFO;
|
||||
info.renderPass = wd->RenderPass;
|
||||
info.framebuffer = fd->Framebuffer;
|
||||
info.renderArea.extent.width = wd->Width;
|
||||
info.renderArea.extent.height = wd->Height;
|
||||
info.clearValueCount = 1;
|
||||
info.pClearValues = &wd->ClearValue;
|
||||
vkCmdBeginRenderPass(fd->CommandBuffer, &info, VK_SUBPASS_CONTENTS_INLINE);
|
||||
}
|
||||
|
||||
// Record Imgui Draw Data and draw funcs into command buffer
|
||||
ImGui_ImplVulkan_RenderDrawData(ImGui::GetDrawData(), fd->CommandBuffer);
|
||||
|
||||
// Submit command buffer
|
||||
vkCmdEndRenderPass(fd->CommandBuffer);
|
||||
{
|
||||
VkPipelineStageFlags wait_stage =
|
||||
VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT;
|
||||
VkSubmitInfo info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
info.waitSemaphoreCount = 1;
|
||||
info.pWaitSemaphores = &image_acquired_semaphore;
|
||||
info.pWaitDstStageMask = &wait_stage;
|
||||
info.commandBufferCount = 1;
|
||||
info.pCommandBuffers = &fd->CommandBuffer;
|
||||
info.signalSemaphoreCount = 1;
|
||||
info.pSignalSemaphores = &render_complete_semaphore;
|
||||
|
||||
err = vkEndCommandBuffer(fd->CommandBuffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(queue, 1, &info, fd->Fence);
|
||||
check_vk_result(err);
|
||||
}
|
||||
}
|
||||
|
||||
void PresentFrame(ImGui_ImplVulkanH_Window* wd, VkQueue queue) {
|
||||
VkSemaphore render_complete_semaphore =
|
||||
wd->FrameSemaphores[wd->SemaphoreIndex].RenderCompleteSemaphore;
|
||||
VkPresentInfoKHR info = {};
|
||||
info.sType = VK_STRUCTURE_TYPE_PRESENT_INFO_KHR;
|
||||
info.waitSemaphoreCount = 1;
|
||||
info.pWaitSemaphores = &render_complete_semaphore;
|
||||
info.swapchainCount = 1;
|
||||
info.pSwapchains = &wd->Swapchain;
|
||||
info.pImageIndices = &wd->FrameIndex;
|
||||
VkResult err = vkQueuePresentKHR(queue, &info);
|
||||
check_vk_result(err);
|
||||
wd->SemaphoreIndex =
|
||||
(wd->SemaphoreIndex + 1) %
|
||||
wd->ImageCount; // Now we can use the next set of semaphores
|
||||
}
|
||||
|
||||
static void CleanupVulkan() {
|
||||
vkDestroyDescriptorPool(g_Device, g_DescriptorPool, g_Allocator);
|
||||
|
||||
vkDestroyDevice(g_Device, g_Allocator);
|
||||
vkDestroyInstance(g_Instance, g_Allocator);
|
||||
}
|
||||
|
||||
static void CleanupVulkanWindow() {
|
||||
ImGui_ImplVulkanH_DestroyWindow(g_Instance, g_Device, &g_MainWindowData,
|
||||
g_Allocator);
|
||||
}
|
||||
|
||||
namespace iree {
|
||||
|
||||
extern "C" int iree_main(int argc, char** argv) {
|
||||
|
||||
iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
|
||||
if (argc > 1) {
|
||||
// Avoid iree-run-module spinning endlessly on stdin if the user uses single
|
||||
// dashes for flags.
|
||||
printf(
|
||||
"[ERROR] unexpected positional argument (expected none)."
|
||||
" Did you use pass a flag with a single dash ('-')?"
|
||||
" Use '--' instead.\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create a window.
|
||||
if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) {
|
||||
fprintf(stderr, "Failed to initialize SDL\n");
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Setup window
|
||||
// clang-format off
|
||||
SDL_WindowFlags window_flags = (SDL_WindowFlags)(
|
||||
SDL_WINDOW_VULKAN | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI);
|
||||
// clang-format on
|
||||
SDL_Window* window = SDL_CreateWindow(
|
||||
"IREE Samples - Vulkan Inference GUI", SDL_WINDOWPOS_CENTERED,
|
||||
SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags);
|
||||
if (window == nullptr)
|
||||
{
|
||||
const char* sdl_err = SDL_GetError();
|
||||
fprintf(stderr, "Error, SDL_CreateWindow returned: %s\n", sdl_err);
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Setup Vulkan
|
||||
iree_hal_vulkan_features_t iree_vulkan_features =
|
||||
static_cast<iree_hal_vulkan_features_t>(
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS |
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
|
||||
std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features);
|
||||
std::vector<const char*> extensions =
|
||||
GetInstanceExtensions(window, iree_vulkan_features);
|
||||
SetupVulkan(iree_vulkan_features, layers.data(),
|
||||
static_cast<uint32_t>(layers.size()), extensions.data(),
|
||||
static_cast<uint32_t>(extensions.size()), g_Allocator,
|
||||
&g_Instance, &g_QueueFamily, &g_PhysicalDevice, &g_Queue,
|
||||
&g_Device, &g_DescriptorPool);
|
||||
|
||||
// Create Window Surface
|
||||
VkSurfaceKHR surface;
|
||||
VkResult err;
|
||||
if (SDL_Vulkan_CreateSurface(window, g_Instance, &surface) == 0) {
|
||||
fprintf(stderr, "Failed to create Vulkan surface.\n");
|
||||
abort();
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Create Framebuffers
|
||||
int w, h;
|
||||
SDL_GetWindowSize(window, &w, &h);
|
||||
ImGui_ImplVulkanH_Window* wd = &g_MainWindowData;
|
||||
SetupVulkanWindow(wd, g_Allocator, g_Instance, g_QueueFamily,
|
||||
g_PhysicalDevice, g_Device, surface, w, h, g_MinImageCount);
|
||||
|
||||
// Setup Dear ImGui context
|
||||
IMGUI_CHECKVERSION();
|
||||
ImGui::CreateContext();
|
||||
ImGuiIO& io = ImGui::GetIO();
|
||||
(void)io;
|
||||
|
||||
ImGui::StyleColorsDark();
|
||||
|
||||
// Setup Platform/Renderer bindings
|
||||
ImGui_ImplSDL2_InitForVulkan(window);
|
||||
ImGui_ImplVulkan_InitInfo init_info = {};
|
||||
init_info.Instance = g_Instance;
|
||||
init_info.PhysicalDevice = g_PhysicalDevice;
|
||||
init_info.Device = g_Device;
|
||||
init_info.QueueFamily = g_QueueFamily;
|
||||
init_info.Queue = g_Queue;
|
||||
init_info.PipelineCache = g_PipelineCache;
|
||||
init_info.DescriptorPool = g_DescriptorPool;
|
||||
init_info.Allocator = g_Allocator;
|
||||
init_info.MinImageCount = g_MinImageCount;
|
||||
init_info.ImageCount = wd->ImageCount;
|
||||
init_info.CheckVkResultFn = check_vk_result;
|
||||
ImGui_ImplVulkan_Init(&init_info, wd->RenderPass);
|
||||
|
||||
// Upload Fonts
|
||||
{
|
||||
// Use any command queue
|
||||
VkCommandPool command_pool = wd->Frames[wd->FrameIndex].CommandPool;
|
||||
VkCommandBuffer command_buffer = wd->Frames[wd->FrameIndex].CommandBuffer;
|
||||
|
||||
err = vkResetCommandPool(g_Device, command_pool, 0);
|
||||
check_vk_result(err);
|
||||
VkCommandBufferBeginInfo begin_info = {};
|
||||
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
begin_info.flags |= VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
err = vkBeginCommandBuffer(command_buffer, &begin_info);
|
||||
check_vk_result(err);
|
||||
|
||||
ImGui_ImplVulkan_CreateFontsTexture(command_buffer);
|
||||
|
||||
VkSubmitInfo end_info = {};
|
||||
end_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
end_info.commandBufferCount = 1;
|
||||
end_info.pCommandBuffers = &command_buffer;
|
||||
err = vkEndCommandBuffer(command_buffer);
|
||||
check_vk_result(err);
|
||||
err = vkQueueSubmit(g_Queue, 1, &end_info, VK_NULL_HANDLE);
|
||||
check_vk_result(err);
|
||||
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
ImGui_ImplVulkan_DestroyFontUploadObjects();
|
||||
}
|
||||
|
||||
// Demo state.
|
||||
bool show_iree_window = true;
|
||||
// --------------------------------------------------------------------------
|
||||
// Setup IREE.
|
||||
|
||||
// Check API version.
|
||||
iree_api_version_t actual_version;
|
||||
iree_status_t status =
|
||||
iree_api_version_check(IREE_API_VERSION_LATEST, &actual_version);
|
||||
if (iree_status_is_ok(status)) {
|
||||
fprintf(stdout, "IREE runtime API version: %d\n", actual_version);
|
||||
} else {
|
||||
fprintf(stderr, "Unsupported runtime API version: %d\n", actual_version);
|
||||
abort();
|
||||
}
|
||||
|
||||
// Create a runtime Instance.
|
||||
iree_vm_instance_t* iree_instance = nullptr;
|
||||
IREE_CHECK_OK(
|
||||
iree_vm_instance_create(iree_allocator_system(), &iree_instance));
|
||||
|
||||
// Register HAL drivers and VM module types.
|
||||
IREE_CHECK_OK(iree_hal_vulkan_driver_module_register(
|
||||
iree_hal_driver_registry_default()));
|
||||
IREE_CHECK_OK(iree_hal_module_register_all_types(iree_instance));
|
||||
|
||||
// Create IREE Vulkan Driver and Device, sharing our VkInstance/VkDevice.
|
||||
fprintf(stdout, "Creating Vulkan driver/device\n");
|
||||
// Load symbols from our static `vkGetInstanceProcAddr` for IREE to use.
|
||||
iree_hal_vulkan_syms_t* iree_vk_syms = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_vulkan_syms_create(
|
||||
reinterpret_cast<void*>(&vkGetInstanceProcAddr), iree_allocator_system(),
|
||||
&iree_vk_syms));
|
||||
// Create the driver sharing our VkInstance.
|
||||
iree_hal_driver_t* iree_vk_driver = nullptr;
|
||||
iree_string_view_t driver_identifier = iree_make_cstring_view("vulkan");
|
||||
iree_hal_vulkan_driver_options_t driver_options;
|
||||
driver_options.api_version = VK_API_VERSION_1_0;
|
||||
driver_options.requested_features = static_cast<iree_hal_vulkan_features_t>(
|
||||
IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS);
|
||||
IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance(
|
||||
driver_identifier, &driver_options, iree_vk_syms, g_Instance,
|
||||
iree_allocator_system(), &iree_vk_driver));
|
||||
// Create a device sharing our VkDevice and queue.
|
||||
// We could also create a separate (possibly low priority) compute queue for
|
||||
// IREE, and/or provide a dedicated transfer queue.
|
||||
iree_string_view_t device_identifier = iree_make_cstring_view("vulkan");
|
||||
iree_hal_vulkan_queue_set_t compute_queue_set;
|
||||
compute_queue_set.queue_family_index = g_QueueFamily;
|
||||
compute_queue_set.queue_indices = 1 << 0;
|
||||
iree_hal_vulkan_queue_set_t transfer_queue_set;
|
||||
transfer_queue_set.queue_indices = 0;
|
||||
iree_hal_device_t* iree_vk_device = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_vulkan_wrap_device(
|
||||
device_identifier, &driver_options.device_options, iree_vk_syms,
|
||||
g_Instance, g_PhysicalDevice, g_Device, &compute_queue_set,
|
||||
&transfer_queue_set, iree_allocator_system(), &iree_vk_device));
|
||||
// Create a HAL module using the HAL device.
|
||||
iree_vm_module_t* hal_module = nullptr;
|
||||
IREE_CHECK_OK(iree_hal_module_create(iree_instance, iree_vk_device,
|
||||
IREE_HAL_MODULE_FLAG_NONE,
|
||||
iree_allocator_system(), &hal_module));
|
||||
|
||||
|
||||
// Load bytecode module
|
||||
//iree_file_toc_t module_file_toc;
|
||||
//const char network_model[] = "resnet50_tf.vmfb";
|
||||
//fprintf(stdout, "Loading: %s\n", network_model);
|
||||
//if (load_file(network_model, &module_file_toc.data, &module_file_toc.size) == false)
|
||||
//{
|
||||
// abort();
|
||||
// return 1;
|
||||
//}
|
||||
//fprintf(stdout, "module size: %zu\n", module_file_toc.size);
|
||||
|
||||
iree_vm_module_t* bytecode_module = nullptr;
|
||||
iree_status_t module_status = iree_tooling_load_module_from_flags(
|
||||
iree_instance, iree_allocator_system(), &bytecode_module);
|
||||
if (!iree_status_is_ok(module_status))
|
||||
return -1;
|
||||
//IREE_CHECK_OK(iree_vm_bytecode_module_create(
|
||||
// iree_instance,
|
||||
// iree_const_byte_span_t{
|
||||
// reinterpret_cast<const uint8_t*>(module_file_toc.data),
|
||||
// module_file_toc.size},
|
||||
// iree_allocator_null(), iree_allocator_system(), &bytecode_module));
|
||||
//// Query for details about what is in the loaded module.
|
||||
//iree_vm_module_signature_t bytecode_module_signature =
|
||||
// iree_vm_module_signature(bytecode_module);
|
||||
//fprintf(stdout, "Module loaded, have <%" PRIhsz "> exported functions:\n",
|
||||
// bytecode_module_signature.export_function_count);
|
||||
//for (int i = 0; i < bytecode_module_signature.export_function_count; ++i) {
|
||||
// iree_vm_function_t function;
|
||||
// IREE_CHECK_OK(iree_vm_module_lookup_function_by_ordinal(
|
||||
// bytecode_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, i, &function));
|
||||
// auto function_name = iree_vm_function_name(&function);
|
||||
// auto function_signature = iree_vm_function_signature(&function);
|
||||
|
||||
// fprintf(stdout, " %d: '%.*s' with calling convention '%.*s'\n", i,
|
||||
// (int)function_name.size, function_name.data,
|
||||
// (int)function_signature.calling_convention.size,
|
||||
// function_signature.calling_convention.data);
|
||||
//}
|
||||
|
||||
// Allocate a context that will hold the module state across invocations.
|
||||
iree_vm_context_t* iree_context = nullptr;
|
||||
std::vector<iree_vm_module_t*> modules = {hal_module, bytecode_module};
|
||||
IREE_CHECK_OK(iree_vm_context_create_with_modules(
|
||||
iree_instance, IREE_VM_CONTEXT_FLAG_NONE, modules.size(), modules.data(),
|
||||
iree_allocator_system(), &iree_context));
|
||||
fprintf(stdout, "Context with modules is ready for use\n");
|
||||
|
||||
// Lookup the entry point function.
|
||||
iree_vm_function_t main_function;
|
||||
const char kMainFunctionName[] = "module.forward";
|
||||
IREE_CHECK_OK(iree_vm_context_resolve_function(
|
||||
iree_context,
|
||||
iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1},
|
||||
&main_function));
|
||||
iree_string_view_t main_function_name = iree_vm_function_name(&main_function);
|
||||
fprintf(stdout, "Resolved main function named '%.*s'\n",
|
||||
(int)main_function_name.size, main_function_name.data);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Write inputs into mappable buffers.
|
||||
iree_hal_allocator_t* allocator =
|
||||
iree_hal_device_allocator(iree_vk_device);
|
||||
//iree_hal_memory_type_t input_memory_type =
|
||||
// static_cast<iree_hal_memory_type_t>(
|
||||
// IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
|
||||
// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE);
|
||||
//iree_hal_buffer_usage_t input_buffer_usage =
|
||||
// static_cast<iree_hal_buffer_usage_t>(IREE_HAL_BUFFER_USAGE_DEFAULT);
|
||||
//iree_hal_buffer_params_t buffer_params;
|
||||
//buffer_params.type = input_memory_type;
|
||||
//buffer_params.usage = input_buffer_usage;
|
||||
//buffer_params.access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE;
|
||||
|
||||
// Wrap input buffers in buffer views.
|
||||
|
||||
vm::ref<iree_vm_list_t> inputs;
|
||||
iree_status_t input_status = ParseToVariantList(
|
||||
allocator,
|
||||
iree::span<const std::string>{FLAG_function_inputs.data(),
|
||||
FLAG_function_inputs.size()},
|
||||
iree_allocator_system(), &inputs);
|
||||
if (!iree_status_is_ok(input_status))
|
||||
return -1;
|
||||
//vm::ref<iree_vm_list_t> inputs;
|
||||
//IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, 6, iree_allocator_system(), &inputs));
|
||||
|
||||
//iree_hal_buffer_view_t* input0_buffer_view = nullptr;
|
||||
//constexpr iree_hal_dim_t input_buffer_shape[] = {1, 224, 224, 3};
|
||||
//IREE_CHECK_OK(iree_hal_buffer_view_allocate_buffer(
|
||||
// allocator,
|
||||
// /*shape_rank=*/4, /*shape=*/input_buffer_shape,
|
||||
// IREE_HAL_ELEMENT_TYPE_FLOAT_32,
|
||||
// IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
|
||||
// iree_make_const_byte_span(&input_res50, sizeof(input_res50)),
|
||||
// &input0_buffer_view));
|
||||
|
||||
//auto input0_buffer_view_ref = iree_hal_buffer_view_move_ref(input0_buffer_view);
|
||||
//IREE_CHECK_OK(iree_vm_list_push_ref_move(inputs.get(), &input0_buffer_view_ref));
|
||||
|
||||
// Prepare outputs list to accept results from the invocation.
|
||||
|
||||
vm::ref<iree_vm_list_t> outputs;
|
||||
constexpr iree_hal_dim_t kOutputCount = 1000;
|
||||
IREE_CHECK_OK(iree_vm_list_create(/*element_type=*/nullptr, kOutputCount * sizeof(float), iree_allocator_system(), &outputs));
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Main loop.
|
||||
bool done = false;
|
||||
while (!done) {
|
||||
SDL_Event event;
|
||||
|
||||
while (SDL_PollEvent(&event)) {
|
||||
if (event.type == SDL_QUIT) {
|
||||
done = true;
|
||||
}
|
||||
|
||||
ImGui_ImplSDL2_ProcessEvent(&event);
|
||||
if (event.type == SDL_QUIT) done = true;
|
||||
if (event.type == SDL_WINDOWEVENT &&
|
||||
event.window.event == SDL_WINDOWEVENT_RESIZED &&
|
||||
event.window.windowID == SDL_GetWindowID(window)) {
|
||||
g_SwapChainResizeWidth = (int)event.window.data1;
|
||||
g_SwapChainResizeHeight = (int)event.window.data2;
|
||||
g_SwapChainRebuild = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (g_SwapChainRebuild) {
|
||||
g_SwapChainRebuild = false;
|
||||
ImGui_ImplVulkan_SetMinImageCount(g_MinImageCount);
|
||||
ImGui_ImplVulkanH_CreateOrResizeWindow(
|
||||
g_Instance, g_PhysicalDevice, g_Device, &g_MainWindowData,
|
||||
g_QueueFamily, g_Allocator, g_SwapChainResizeWidth,
|
||||
g_SwapChainResizeHeight, g_MinImageCount);
|
||||
g_MainWindowData.FrameIndex = 0;
|
||||
}
|
||||
|
||||
// Start the Dear ImGui frame
|
||||
ImGui_ImplVulkan_NewFrame();
|
||||
ImGui_ImplSDL2_NewFrame(window);
|
||||
ImGui::NewFrame();
|
||||
|
||||
// Custom window.
|
||||
{
|
||||
ImGui::Begin("IREE Vulkan Integration Demo", &show_iree_window);
|
||||
|
||||
ImGui::Separator();
|
||||
|
||||
// ImGui Inputs for two input tensors.
|
||||
// Run computation whenever any of the values changes.
|
||||
static bool dirty = true;
|
||||
if (dirty) {
|
||||
|
||||
// Synchronously invoke the function.
|
||||
IREE_CHECK_OK(iree_vm_invoke(iree_context, main_function,
|
||||
IREE_VM_INVOCATION_FLAG_NONE,
|
||||
/*policy=*/nullptr, inputs.get(),
|
||||
outputs.get(), iree_allocator_system()));
|
||||
|
||||
|
||||
// we want to run continuously so we can use tools like RenderDoc, RGP, etc...
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// Framerate counter.
|
||||
ImGui::Text("Application average %.3f ms/frame (%.1f FPS)",
|
||||
1000.0f / ImGui::GetIO().Framerate, ImGui::GetIO().Framerate);
|
||||
|
||||
ImGui::End();
|
||||
}
|
||||
|
||||
// Rendering
|
||||
ImGui::Render();
|
||||
RenderFrame(wd, g_Device, g_Queue);
|
||||
|
||||
PresentFrame(wd, g_Queue);
|
||||
}
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Cleanup
|
||||
iree_vm_module_release(hal_module);
|
||||
iree_vm_module_release(bytecode_module);
|
||||
iree_vm_context_release(iree_context);
|
||||
iree_hal_device_release(iree_vk_device);
|
||||
iree_hal_allocator_release(allocator);
|
||||
iree_hal_driver_release(iree_vk_driver);
|
||||
iree_hal_vulkan_syms_release(iree_vk_syms);
|
||||
iree_vm_instance_release(iree_instance);
|
||||
|
||||
err = vkDeviceWaitIdle(g_Device);
|
||||
check_vk_result(err);
|
||||
ImGui_ImplVulkan_Shutdown();
|
||||
ImGui_ImplSDL2_Shutdown();
|
||||
ImGui::DestroyContext();
|
||||
|
||||
CleanupVulkanWindow();
|
||||
CleanupVulkan();
|
||||
|
||||
SDL_DestroyWindow(window);
|
||||
SDL_Quit();
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace iree
|
||||
1160
cpp/vulkan_gui/vulkan_resnet_inference_gui.cc
Normal file
1160
cpp/vulkan_gui/vulkan_resnet_inference_gui.cc
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,8 +8,8 @@ wheel
|
||||
|
||||
torch==2.3.0
|
||||
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release
|
||||
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@deprecated-constraints#subdirectory=models
|
||||
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
|
||||
|
||||
# SHARK Runner
|
||||
|
||||
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install --pre -r requirements.txt
|
||||
pip install https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_compiler-20240602.283-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240602.283/iree_runtime-20240602.283-cp311-cp311-win_amd64.whl
|
||||
pip install --force-reinstall https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_compiler-20240528.279-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_runtime-20240528.279-cp311-cp311-win_amd64.whl
|
||||
pip install -e .
|
||||
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
|
||||
28
shark/__init__.py
Normal file
28
shark/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
from torch._dynamo import register_backend
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_backend
|
||||
def shark(model, inputs, *, options):
|
||||
try:
|
||||
from shark.dynamo_backend.utils import SharkBackend
|
||||
except ImportError:
|
||||
log.exception(
|
||||
"Unable to import SHARK - High Performance Machine Learning Distribution"
|
||||
"Please install the right version of SHARK that matches the PyTorch version being used. "
|
||||
"Refer to https://github.com/nod-ai/SHARK/ for details."
|
||||
)
|
||||
raise
|
||||
return SharkBackend(model, inputs, options)
|
||||
|
||||
|
||||
def has_shark():
|
||||
try:
|
||||
importlib.import_module("shark")
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
78
shark/backward_makefx.py
Normal file
78
shark/backward_makefx.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.nn.utils import stateless
|
||||
|
||||
from torch import fx
|
||||
import tempfile
|
||||
|
||||
|
||||
class MakeFxModule:
|
||||
def __init__(self, model, inputs, labels=None, custom_inference_fn=None):
|
||||
self.model = model
|
||||
self.inputs = inputs
|
||||
self.custom_inference_fn = custom_inference_fn
|
||||
self.training_graph = None
|
||||
|
||||
# Doesn't replace the None type.
|
||||
def change_fx_graph_return_to_tuple(self, fx_g: fx.GraphModule):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
# output nodes always have one argument
|
||||
node_arg = node.args[0]
|
||||
out_nodes = []
|
||||
if isinstance(node_arg, list):
|
||||
# Don't return NoneType elements.
|
||||
for out_node in node_arg:
|
||||
if not isinstance(out_node, type(None)):
|
||||
out_nodes.append(out_node)
|
||||
# If there is a single tensor/element to be returned don't
|
||||
# a tuple for it.
|
||||
if len(out_nodes) == 1:
|
||||
node.args = out_nodes
|
||||
else:
|
||||
node.args = (tuple(out_nodes),)
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
def generate_graph(self):
|
||||
fx_g = make_fx(
|
||||
self.custom_inference_fn,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
]
|
||||
),
|
||||
)(
|
||||
dict(self.model.named_parameters()),
|
||||
dict(self.model.named_buffers()),
|
||||
self.inputs,
|
||||
)
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
fx_g = self.change_fx_graph_return_to_tuple(fx_g)
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
temp = tempfile.NamedTemporaryFile(
|
||||
suffix="_shark_ts", prefix="temp_ts_"
|
||||
)
|
||||
ts_g.save(temp.name)
|
||||
new_ts = torch.jit.load(temp.name)
|
||||
self.training_graph = new_ts
|
||||
0
shark/dynamo_backend/__init__.py
Normal file
0
shark/dynamo_backend/__init__.py
Normal file
154
shark/dynamo_backend/utils.py
Normal file
154
shark/dynamo_backend/utils.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.func import functionalize
|
||||
import io
|
||||
import torch_mlir
|
||||
|
||||
|
||||
# TODO: Control decompositions.
|
||||
def default_decompositions():
|
||||
return get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
class SharkBackend:
|
||||
def __init__(
|
||||
self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
|
||||
):
|
||||
self.fx_g = fx_g
|
||||
self.inputs = inputs
|
||||
self.shark_module = None
|
||||
self.device: str = options.get("device", "cpu")
|
||||
self.was_unwrapped: bool = False
|
||||
self.none_indices: list = []
|
||||
self._modify_fx_g()
|
||||
self.compile()
|
||||
|
||||
def _modify_fx_g(self):
|
||||
self.none_indices = _remove_nones(self.fx_g)
|
||||
self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)
|
||||
|
||||
def compile(self):
|
||||
gm = make_fx(
|
||||
functionalize(self.fx_g),
|
||||
decomposition_table=default_decompositions(),
|
||||
)(*self.inputs)
|
||||
gm.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
gm.recompile()
|
||||
strip_overloads(gm)
|
||||
ts_g = torch.jit.script(gm)
|
||||
mlir_module = torch_mlir.compile(
|
||||
ts_g, self.inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=self.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args=[])
|
||||
self.shark_module = shark_module
|
||||
|
||||
def __call__(self, *inputs):
|
||||
np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs]
|
||||
np_outs = self.shark_module("forward", np_inputs)
|
||||
if self.was_unwrapped:
|
||||
np_outs = [
|
||||
np_outs,
|
||||
]
|
||||
|
||||
if not isinstance(np_outs, list):
|
||||
res = torch.from_numpy(np_outs)
|
||||
return res
|
||||
|
||||
result = [torch.from_numpy(x) for x in np_outs]
|
||||
for r_in in self.none_indices:
|
||||
result.insert(r_in, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
25
shark/examples/shark_dynamo/basic_examples.py
Normal file
25
shark/examples/shark_dynamo/basic_examples.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import shark
|
||||
|
||||
|
||||
def foo(x, a):
|
||||
if x.shape[0] > 3:
|
||||
return x + a
|
||||
else:
|
||||
return x + 3
|
||||
|
||||
|
||||
shark_options = {"device": "cpu"}
|
||||
compiled = torch.compile(foo, backend="shark", options=shark_options)
|
||||
|
||||
input = torch.ones(4)
|
||||
|
||||
x = compiled(input, input)
|
||||
|
||||
print(x)
|
||||
|
||||
input = torch.ones(3)
|
||||
|
||||
x = compiled(input, input)
|
||||
|
||||
print(x)
|
||||
309
shark/examples/shark_eager/dynamo_demo.ipynb
Normal file
309
shark/examples/shark_eager/dynamo_demo.ipynb
Normal file
@@ -0,0 +1,309 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# standard imports\n",
|
||||
"import torch\n",
|
||||
"from shark.iree_utils import get_iree_compiled_module"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# torch dynamo related imports\n",
|
||||
"try:\n",
|
||||
" import torchdynamo\n",
|
||||
" from torchdynamo.optimizations.backends import create_backend\n",
|
||||
" from torchdynamo.optimizations.subgraph import SubGraph\n",
|
||||
"except ModuleNotFoundError:\n",
|
||||
" print(\n",
|
||||
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
|
||||
" )\n",
|
||||
" exit()\n",
|
||||
"\n",
|
||||
"# torch-mlir imports for compiling\n",
|
||||
"from torch_mlir import compile, OutputType"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends."
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def toy_example(*args):\n",
|
||||
" a, b = args\n",
|
||||
"\n",
|
||||
" x = a / (torch.abs(a) + 1)\n",
|
||||
" if b.sum() < 0:\n",
|
||||
" b = b * -1\n",
|
||||
" return x * b"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# compiler that lowers fx_graph to through MLIR\n",
|
||||
"def __torch_mlir(fx_graph, *args, **kwargs):\n",
|
||||
" assert isinstance(\n",
|
||||
" fx_graph, torch.fx.GraphModule\n",
|
||||
" ), \"Model must be an FX GraphModule.\"\n",
|
||||
"\n",
|
||||
" def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):\n",
|
||||
" \"\"\"Replace tuple with tuple element in functions that return one-element tuples.\"\"\"\n",
|
||||
"\n",
|
||||
" for node in fx_g.graph.nodes:\n",
|
||||
" if node.op == \"output\":\n",
|
||||
" assert (\n",
|
||||
" len(node.args) == 1\n",
|
||||
" ), \"Output node must have a single argument\"\n",
|
||||
" node_arg = node.args[0]\n",
|
||||
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
|
||||
" node.args = (node_arg[0],)\n",
|
||||
" fx_g.graph.lint()\n",
|
||||
" fx_g.recompile()\n",
|
||||
" return fx_g\n",
|
||||
"\n",
|
||||
" fx_graph = _unwrap_single_tuple_return(fx_graph)\n",
|
||||
" ts_graph = torch.jit.script(fx_graph)\n",
|
||||
"\n",
|
||||
" # torchdynamo does munges the args differently depending on whether you use\n",
|
||||
" # the @torchdynamo.optimize decorator or the context manager\n",
|
||||
" if isinstance(args, tuple):\n",
|
||||
" args = list(args)\n",
|
||||
" assert isinstance(args, list)\n",
|
||||
" if len(args) == 1 and isinstance(args[0], list):\n",
|
||||
" args = args[0]\n",
|
||||
"\n",
|
||||
" linalg_module = compile(\n",
|
||||
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
|
||||
" )\n",
|
||||
" callable, _ = get_iree_compiled_module(\n",
|
||||
" linalg_module, \"cuda\", func_name=\"forward\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(*inputs):\n",
|
||||
" return callable(*inputs)\n",
|
||||
"\n",
|
||||
" return forward"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Simplest way to use TorchDynamo with the `torchdynamo.optimize` context manager:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 1 device(s).\n",
|
||||
"Device: 0\n",
|
||||
" Name: NVIDIA GeForce RTX 3080\n",
|
||||
" Compute Capability: 8.6\n",
|
||||
"[-0.40066046 -0.4210303 0.03225489 -0.44849953 0.10370405 -0.04422468\n",
|
||||
" 0.33262825 -0.20109026 0.02102537 -0.24882983]\n",
|
||||
"[-0.07824923 -0.17004533 0.06439921 -0.06163602 0.26633525 -1.1560082\n",
|
||||
" -0.06660341 0.24227881 0.1462235 -0.32055548]\n",
|
||||
"[-0.01464001 0.442209 -0.0607936 -0.5477967 -0.25226554 -0.08588809\n",
|
||||
" -0.30497575 0.00061084 -0.50069696 0.2317973 ]\n",
|
||||
"[ 0.25726247 0.39388427 -0.24093066 0.12316308 -0.01981307 0.5661146\n",
|
||||
" 0.26199922 0.8123446 -0.01576749 0.30846444]\n",
|
||||
"[ 0.7878203 -0.45975062 -0.29956317 -0.07032048 -0.55817443 -0.62506855\n",
|
||||
" -1.6837492 -0.38442805 0.28220773 -1.5325156 ]\n",
|
||||
"[ 0.07975311 0.67754704 -0.30927914 0.00347631 -0.07326564 0.01893554\n",
|
||||
" -0.7518105 -0.03078967 -0.07623022 0.38865626]\n",
|
||||
"[-0.7751679 -0.5841397 -0.6622711 0.18574935 -0.6049372 0.02844244\n",
|
||||
" -0.20471913 0.3337415 -0.3619432 -0.35087156]\n",
|
||||
"[-0.08569919 -0.10775139 -0.02338934 0.21933547 -0.46712473 0.00062137\n",
|
||||
" -0.58207744 0.06457533 0.18276742 0.03866556]\n",
|
||||
"[-0.2311981 -0.43036282 0.20561649 -0.10363232 -0.13248594 0.02885137\n",
|
||||
" -0.31241602 -0.36907142 0.08861586 0.2331427 ]\n",
|
||||
"[-0.07273526 -0.31246194 -0.24218291 -0.24145737 0.0364486 0.14382267\n",
|
||||
" -0.00531162 0.15447603 -0.5220248 -0.09016377]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with torchdynamo.optimize(__torch_mlir):\n",
|
||||
" for _ in range(10):\n",
|
||||
" print(toy_example(torch.randn(10), torch.randn(10)))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"It can also be used through a decorator:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@create_backend\n",
|
||||
"def torch_mlir(subgraph, *args, **kwargs):\n",
|
||||
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
|
||||
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@torchdynamo.optimize(\"torch_mlir\")\n",
|
||||
"def toy_example2(*args):\n",
|
||||
" a, b = args\n",
|
||||
"\n",
|
||||
" x = a / (torch.abs(a) + 1)\n",
|
||||
" if b.sum() < 0:\n",
|
||||
" b = b * -1\n",
|
||||
" return x * b"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found 1 device(s).\n",
|
||||
"Device: 0\n",
|
||||
" Name: NVIDIA GeForce RTX 3080\n",
|
||||
" Compute Capability: 8.6\n",
|
||||
"[-0.35494277 0.03409214 -0.02271946 0.7335942 0.03122527 -0.41881397\n",
|
||||
" -0.6609761 -0.6418614 0.29336175 -0.01973678]\n",
|
||||
"[-2.7246824e-01 -3.5543957e-01 6.0087401e-01 -7.4570496e-03\n",
|
||||
" -4.2481605e-02 -5.0296803e-04 7.2928613e-01 -1.4673788e-03\n",
|
||||
" -2.7621329e-01 -6.0995776e-02]\n",
|
||||
"[-0.03165906 0.3889693 0.24052973 0.27279532 -0.02773128 -0.12602475\n",
|
||||
" -1.0124422 0.5720256 -0.35437614 -0.20992722]\n",
|
||||
"[-0.41831446 0.5525326 -0.29749998 -0.17044766 0.11804754 -0.05210691\n",
|
||||
" -0.46145165 -0.8776549 0.10090438 0.17463352]\n",
|
||||
"[ 0.02194221 0.20959911 0.26973712 0.12551276 -0.0020404 0.1490246\n",
|
||||
" -0.04456685 1.1100804 0.8105744 0.6676846 ]\n",
|
||||
"[ 0.06528181 -0.13591261 0.5370964 -0.4398162 -0.03372452 0.9691372\n",
|
||||
" -0.01120087 0.2947028 0.4804801 -0.3324341 ]\n",
|
||||
"[ 0.33549032 -0.23001772 -0.08681437 0.16490957 -0.11223086 0.09168988\n",
|
||||
" 0.02403045 0.17344482 0.46406478 -0.00129451]\n",
|
||||
"[-0.27475086 0.42384806 1.9090122 -0.41147137 -0.6888369 0.08435658\n",
|
||||
" -0.26628923 -0.17436793 -0.8058869 -0.02582378]\n",
|
||||
"[-0.10109414 0.08681287 -0.10055986 0.6858881 0.29267687 -0.02797117\n",
|
||||
" -0.01425194 0.4882803 0.3551982 -0.858935 ]\n",
|
||||
"[-0.22086617 0.524994 0.17721705 -0.03813264 -0.54570735 -0.4421502\n",
|
||||
" 0.11938014 -0.01122053 0.39294165 -0.61770755]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for _ in range(10):\n",
|
||||
" print(toy_example2(torch.randn(10), torch.randn(10)))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
92
shark/examples/shark_eager/dynamo_demo.py
Normal file
92
shark/examples/shark_eager/dynamo_demo.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
from torch_mlir import compile, OutputType
|
||||
|
||||
from shark.iree_utils import get_iree_compiled_module
|
||||
|
||||
try:
|
||||
import torchdynamo
|
||||
from torchdynamo.optimizations.backends import create_backend
|
||||
from torchdynamo.optimizations.subgraph import SubGraph
|
||||
except ModuleNotFoundError:
|
||||
print(
|
||||
"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo"
|
||||
)
|
||||
exit()
|
||||
|
||||
NUM_ITERS = 10
|
||||
|
||||
|
||||
def __torch_mlir(fx_graph, *args, **kwargs):
|
||||
assert isinstance(
|
||||
fx_graph, torch.fx.GraphModule
|
||||
), "Model must be an FX GraphModule."
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):
|
||||
"""Replace tuple with tuple element in functions that return one-element tuples."""
|
||||
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple) and len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return fx_g
|
||||
|
||||
fx_graph = _unwrap_single_tuple_return(fx_graph)
|
||||
ts_graph = torch.jit.script(fx_graph)
|
||||
|
||||
if isinstance(args, tuple):
|
||||
args = list(args)
|
||||
assert isinstance(args, list)
|
||||
if len(args) == 1 and isinstance(args[0], list):
|
||||
args = args[0]
|
||||
|
||||
linalg_module = compile(
|
||||
ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS
|
||||
)
|
||||
callable, _ = get_iree_compiled_module(
|
||||
linalg_module, "cuda", func_name="forward"
|
||||
)
|
||||
|
||||
def forward(*inputs):
|
||||
return callable(*inputs)
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def toy_example(*args):
|
||||
a, b = args
|
||||
|
||||
x = a / (torch.abs(a) + 1)
|
||||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
|
||||
|
||||
with torchdynamo.optimize(__torch_mlir):
|
||||
for _ in range(10):
|
||||
print(toy_example(torch.randn(10), torch.randn(10)))
|
||||
|
||||
|
||||
@create_backend
|
||||
def torch_mlir(subgraph, *args, **kwargs):
|
||||
assert isinstance(subgraph, SubGraph), "Model must be a dynamo SubGraph."
|
||||
return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))
|
||||
|
||||
|
||||
@torchdynamo.optimize("torch_mlir")
|
||||
def toy_example2(*args):
|
||||
a, b = args
|
||||
|
||||
x = a / (torch.abs(a) + 1)
|
||||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
|
||||
|
||||
for _ in range(10):
|
||||
print(toy_example2(torch.randn(10), torch.randn(10)))
|
||||
805
shark/examples/shark_eager/eager_mode.ipynb
Normal file
805
shark/examples/shark_eager/eager_mode.ipynb
Normal file
@@ -0,0 +1,805 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# standard imports\n",
|
||||
"import torch\n",
|
||||
"from torch_mlir.eager_mode import torch_mlir_tensor"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# eager mode imports\n",
|
||||
"from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor\n",
|
||||
"from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"The simplest way of using Eager Mode (through IREE) requires setting a \"backend\":"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(\"cpu\")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"and wrapping all your `torch.Tensor`s:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"NUM_ITERS = 10\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = 2 * torch.ones((10, 10))\n",
|
||||
"\n",
|
||||
"tt = TorchMLIRTensor(t)\n",
|
||||
"print(tt)\n",
|
||||
"uu = TorchMLIRTensor(u)\n",
|
||||
"print(uu)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"`TorchMLIRTensor` is a \"tensor wrapper subclass\" (more info [here](https://github.com/albanD/subclass_zoo)) that keeps the IREE `DeviceArray` in a field `elem`:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for i in range(NUM_ITERS):\n",
|
||||
" yy = tt + uu\n",
|
||||
" print(type(yy))\n",
|
||||
" print(yy.elem.to_host())\n",
|
||||
" yy = tt * uu\n",
|
||||
" print(type(yy))\n",
|
||||
" print(yy.elem.to_host())"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"If you have a GPU (and CUDA installed) that works too (you can verify by having `watch -n1 nvidia-smi` up in a terminal while running the next cell):"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"[[3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]\n",
|
||||
" [3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]]\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend(\"gpu\")\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = 2 * torch.ones((10, 10))\n",
|
||||
"\n",
|
||||
"tt = TorchMLIRTensor(t)\n",
|
||||
"print(tt)\n",
|
||||
"uu = TorchMLIRTensor(u)\n",
|
||||
"print(uu)\n",
|
||||
"\n",
|
||||
"yy = tt + uu\n",
|
||||
"print(yy.elem.to_host())\n",
|
||||
"yy = tt * uu\n",
|
||||
"print(yy.elem.to_host())"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"There is a convenience class `SharkEagerMode` that will handle both the installation of the backend and the wrapping of `torch.Tensor`s:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# eager mode RAII\n",
|
||||
"from shark.shark_runner import SharkEagerMode\n",
|
||||
"\n",
|
||||
"shark_eager_mode = SharkEagerMode(\"cpu\")\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = torch.ones((10, 10))\n",
|
||||
"\n",
|
||||
"print(t)\n",
|
||||
"print(u)\n",
|
||||
"\n",
|
||||
"for i in range(NUM_ITERS):\n",
|
||||
" yy = t + u\n",
|
||||
" print(type(yy))\n",
|
||||
" print(yy.elem.to_host())\n",
|
||||
" yy = t * u\n",
|
||||
" print(type(yy))\n",
|
||||
" print(yy.elem.to_host())"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"The `SharkEagerMode` class is a hacky take on [RAII](https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization) that defines a \"deleter\" that runs when an instantiation (of `SharkEagerMode`) is garbage collected. Takeaway is that if you want to turn off `SharkEagerMode`, or switch backends, you need to `del` the instance:"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"TorchMLIRTensor(<IREE DeviceArray: shape=[10, 10], dtype=float32>, backend=EagerModeIREELinalgOnTensorsBackend)\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]\n",
|
||||
" [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]\n",
|
||||
"<class 'torch_mlir.eager_mode.torch_mlir_tensor.TorchMLIRTensor'>\n",
|
||||
"[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
|
||||
" [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"del shark_eager_mode\n",
|
||||
"shark_eager_mode = SharkEagerMode(\"cuda\")\n",
|
||||
"\n",
|
||||
"t = torch.ones((10, 10))\n",
|
||||
"u = torch.ones((10, 10))\n",
|
||||
"\n",
|
||||
"print(t)\n",
|
||||
"print(u)\n",
|
||||
"\n",
|
||||
"yy = t + u\n",
|
||||
"print(type(yy))\n",
|
||||
"print(yy.elem.to_host())\n",
|
||||
"yy = t * u\n",
|
||||
"print(type(yy))\n",
|
||||
"print(yy.elem.to_host())"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
148
shark/examples/shark_eager/eager_mode.py
Normal file
148
shark/examples/shark_eager/eager_mode.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load_inline, include_paths
|
||||
from torch_mlir.eager_mode import torch_mlir_tensor
|
||||
from torch_mlir.eager_mode.torch_mlir_tensor import TorchMLIRTensor
|
||||
|
||||
from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
|
||||
from shark.shark_runner import SharkEagerMode
|
||||
|
||||
|
||||
def test_cpu():
|
||||
torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend("cpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = 2 * torch.ones((10, 10), device="cpu")
|
||||
|
||||
tt = TorchMLIRTensor(t)
|
||||
print(tt)
|
||||
uu = TorchMLIRTensor(u)
|
||||
print(uu)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
yy = tt + uu
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
yy = tt * uu
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
|
||||
|
||||
def test_gpu():
|
||||
source = """
|
||||
#include <iostream>
|
||||
#include "cuda.h"
|
||||
#include "cuda_runtime_api.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
void print_free_mem() {
|
||||
int num_gpus;
|
||||
size_t free, total;
|
||||
cudaSetDevice(0);
|
||||
int id;
|
||||
cudaGetDevice(&id);
|
||||
cudaMemGetInfo(&free, &total);
|
||||
cout << "GPU " << id << " memory: used=" << (total-free)/(1<<20) << endl;
|
||||
}
|
||||
"""
|
||||
gpu_stats = load_inline(
|
||||
name="inline_extension",
|
||||
cpp_sources=[source],
|
||||
extra_include_paths=include_paths(cuda=True),
|
||||
functions=["print_free_mem"],
|
||||
)
|
||||
torch_mlir_tensor.backend = EagerModeIREELinalgOnTensorsBackend("gpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = 2 * torch.ones((10, 10), device="cpu")
|
||||
|
||||
tt = TorchMLIRTensor(t)
|
||||
print(tt)
|
||||
uu = TorchMLIRTensor(u)
|
||||
print(uu)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
yy = tt + uu
|
||||
print(yy.elem.to_host())
|
||||
yy = tt * uu
|
||||
print(yy.elem.to_host())
|
||||
gpu_stats.print_free_mem()
|
||||
|
||||
|
||||
def test_python_mode_ref_backend():
|
||||
# hide this wherever you want?
|
||||
_ = SharkEagerMode("refbackend")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
|
||||
print(t)
|
||||
print(u)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
print(i)
|
||||
yy = t + u
|
||||
print(yy.elem)
|
||||
yy = t * u
|
||||
print(yy.elem)
|
||||
|
||||
|
||||
def test_python_mode_iree_cpu():
|
||||
# hide this wherever you want?
|
||||
_ = SharkEagerMode("cpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
|
||||
print(t)
|
||||
print(u)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
yy = t + u
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
yy = t * u
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
|
||||
|
||||
def test_python_mode_iree_gpu():
|
||||
_ = SharkEagerMode("gpu")
|
||||
|
||||
t = torch.ones((10, 10), device="cpu")
|
||||
u = torch.ones((10, 10), device="cpu")
|
||||
|
||||
print(t)
|
||||
print(u)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
yy = t + u
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
yy = t * u
|
||||
print(type(yy))
|
||||
print(yy.elem.to_host())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
NUM_ITERS = 10
|
||||
test_cpu()
|
||||
if torch.cuda.is_available():
|
||||
test_gpu()
|
||||
test_python_mode_ref_backend()
|
||||
test_python_mode_iree_cpu()
|
||||
test_python_mode_iree_gpu()
|
||||
73
shark/examples/shark_eager/squeezenet_lockstep.py
Normal file
73
shark/examples/shark_eager/squeezenet_lockstep.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
model = torch.hub.load(
|
||||
"pytorch/vision:v0.10.0", "squeezenet1_0", pretrained=True
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# from PIL import Image
|
||||
# from torchvision import transforms
|
||||
# import urllib
|
||||
#
|
||||
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
|
||||
# try: urllib.URLopener().retrieve(url, filename)
|
||||
# except: urllib.request.urlretrieve(url, filename)
|
||||
#
|
||||
#
|
||||
# input_image = Image.open(filename)
|
||||
# preprocess = transforms.Compose([
|
||||
# transforms.Resize(256),
|
||||
# transforms.CenterCrop(224),
|
||||
# transforms.ToTensor(),
|
||||
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
# ])
|
||||
# input_tensor = preprocess(input_image)
|
||||
# input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
|
||||
# print(input_batch.shape) # size = [1, 3, 224, 224]
|
||||
|
||||
# The above is code for generating sample inputs from an image. We can just use
|
||||
# random values for accuracy testing though
|
||||
input_batch = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
# Focus on CPU for now
|
||||
if False and torch.cuda.is_available():
|
||||
input_batch = input_batch.to("cuda")
|
||||
model.to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_batch)
|
||||
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
|
||||
golden_confidences = output[0]
|
||||
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
|
||||
golden_probabilities = torch.nn.functional.softmax(
|
||||
golden_confidences, dim=0
|
||||
).numpy()
|
||||
|
||||
golden_confidences = golden_confidences.numpy()
|
||||
|
||||
from shark.torch_mlir_lockstep_tensor import TorchMLIRLockstepTensor
|
||||
|
||||
input_detached_clone = input_batch.clone()
|
||||
eager_input_batch = TorchMLIRLockstepTensor(input_detached_clone)
|
||||
|
||||
print("getting torch-mlir result")
|
||||
|
||||
output = model(eager_input_batch)
|
||||
|
||||
static_output = output.elem
|
||||
confidences = static_output[0]
|
||||
probabilities = torch.nn.functional.softmax(
|
||||
torch.from_numpy(confidences), dim=0
|
||||
).numpy()
|
||||
|
||||
print("The obtained result via shark is: ", confidences)
|
||||
print("The golden result is:", golden_confidences)
|
||||
|
||||
np.testing.assert_allclose(
|
||||
golden_confidences, confidences, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
golden_probabilities, probabilities, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
65
shark/examples/shark_inference/CLIPModel_tf.py
Normal file
65
shark/examples/shark_inference/CLIPModel_tf.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import CLIPProcessor, TFCLIPModel
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
clip_vit_inputs = [
|
||||
tf.TensorSpec(shape=[2, 7], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[2, 7], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[1, 3, 224, 224], dtype=tf.float32),
|
||||
]
|
||||
|
||||
|
||||
class CLIPModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(CLIPModule, self).__init__()
|
||||
self.m = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
self.m.predict = lambda x, y, z: self.m(
|
||||
input_ids=x, attention_mask=y, pixel_values=z
|
||||
)
|
||||
|
||||
@tf.function(input_signature=clip_vit_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, pixel_values):
|
||||
return self.m.predict(
|
||||
input_ids, attention_mask, pixel_values
|
||||
).logits_per_image
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
inputs = processor(
|
||||
text=["a photo of a cat", "a photo of a dog"],
|
||||
images=image,
|
||||
return_tensors="tf",
|
||||
padding=True,
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
CLIPModule(),
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["pixel_values"],
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
|
||||
print(
|
||||
shark_module.forward(
|
||||
(
|
||||
inputs["input_ids"],
|
||||
inputs["attention_mask"],
|
||||
inputs["pixel_values"],
|
||||
)
|
||||
)
|
||||
)
|
||||
15
shark/examples/shark_inference/ESRGAN/README.md
Normal file
15
shark/examples/shark_inference/ESRGAN/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
## Running ESRGAN
|
||||
|
||||
```
|
||||
1. pip install numpy opencv-python
|
||||
2. mkdir InputImages
|
||||
(this is where all the input images will reside in)
|
||||
3. mkdir OutputImages
|
||||
(this is where the model will generate all the images)
|
||||
4. mkdir models
|
||||
(save the .pth checkpoint file here)
|
||||
5. python esrgan.py
|
||||
```
|
||||
|
||||
- Download [RRDB_ESRGAN_x4.pth](https://drive.google.com/drive/u/0/folders/17VYV_SoZZesU6mbxz2dMAIccSSlqLecY) and place it in the `models` directory as mentioned above in step 4.
|
||||
- Credits : [ESRGAN](https://github.com/xinntao/ESRGAN)
|
||||
239
shark/examples/shark_inference/ESRGAN/esrgan.py
Normal file
239
shark/examples/shark_inference/ESRGAN/esrgan.py
Normal file
@@ -0,0 +1,239 @@
|
||||
from ast import arg
|
||||
import os.path as osp
|
||||
import glob
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
import torch_mlir
|
||||
import tempfile
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
"""Residual in Residual Dense Block"""
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(
|
||||
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
fea = self.lrelu(
|
||||
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
############### Parsing args #####################
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
p.add_argument("--device", type=str, default="cpu", help="the device to use")
|
||||
p.add_argument(
|
||||
"--mlir_loc",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location of the model's mlir file",
|
||||
)
|
||||
args = p.parse_args()
|
||||
###################################################
|
||||
|
||||
|
||||
def inference(input_m):
|
||||
return model(input_m)
|
||||
|
||||
|
||||
def load_mlir(mlir_loc):
|
||||
import os
|
||||
|
||||
if mlir_loc == None:
|
||||
return None
|
||||
print(f"Trying to load the model from {mlir_loc}.")
|
||||
with open(os.path.join(mlir_loc)) as f:
|
||||
mlir_module = f.read()
|
||||
return mlir_module
|
||||
|
||||
|
||||
def compile_through_fx(model, inputs, mlir_loc=None):
|
||||
module = load_mlir(mlir_loc)
|
||||
if module == None:
|
||||
fx_g = make_fx(
|
||||
model,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
)(inputs)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
print("Torchscript graph generated successfully")
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
inputs,
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
mlir_model = str(module)
|
||||
func_name = "forward"
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
return shark_module
|
||||
|
||||
|
||||
model_path = "models/RRDB_ESRGAN_x4.pth" # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
|
||||
# device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu
|
||||
device = torch.device("cpu")
|
||||
|
||||
test_img_folder = "InputImages/*"
|
||||
|
||||
model = RRDBNet(3, 3, 64, 23, gc=32)
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
print("Model path {:s}. \nTesting...".format(model_path))
|
||||
|
||||
if __name__ == "__main__":
|
||||
idx = 0
|
||||
for path in glob.glob(test_img_folder):
|
||||
idx += 1
|
||||
base = osp.splitext(osp.basename(path))[0]
|
||||
print(idx, base)
|
||||
# read images
|
||||
img = cv2.imread(path, cv2.IMREAD_COLOR)
|
||||
img = img * 1.0 / 255
|
||||
img = torch.from_numpy(
|
||||
np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
).float()
|
||||
img_LR = img.unsqueeze(0)
|
||||
img_LR = img_LR.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
shark_module = compile_through_fx(inference, img_LR)
|
||||
shark_output = shark_module.forward((img_LR,))
|
||||
shark_output = torch.from_numpy(shark_output)
|
||||
shark_output = (
|
||||
shark_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
esrgan_output = (
|
||||
model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
)
|
||||
# SHARK OUTPUT
|
||||
shark_output = np.transpose(shark_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
shark_output = (shark_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_shark_output.png".format(base), shark_output
|
||||
)
|
||||
print("Generated SHARK's output")
|
||||
# ESRGAN OUTPUT
|
||||
esrgan_output = np.transpose(esrgan_output[[2, 1, 0], :, :], (1, 2, 0))
|
||||
esrgan_output = (esrgan_output * 255.0).round()
|
||||
cv2.imwrite(
|
||||
"OutputImages/{:s}_rlt_esrgan_output.png".format(base),
|
||||
esrgan_output,
|
||||
)
|
||||
print("Generated ESRGAN's output")
|
||||
86
shark/examples/shark_inference/albert_maskfill_pt.py
Normal file
86
shark/examples/shark_inference/albert_maskfill_pt.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
|
||||
class AlbertModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.model(
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).logits
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
|
||||
text = "This [MASK] is very tasty."
|
||||
encoded_inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
|
||||
mlir_importer = SharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="torch",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
encoded_inputs["input_ids"] == tokenizer.mask_token_id
|
||||
)[1]
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence with [MASK] to fill: ")
|
||||
encoded_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
encoded_inputs["input_ids"] == tokenizer.mask_token_id
|
||||
)[1]
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = (
|
||||
torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
|
||||
)
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
100
shark/examples/shark_inference/albert_maskfill_tf.py
Normal file
100
shark/examples/shark_inference/albert_maskfill_tf.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import TFAutoModelForMaskedLM, AutoTokenizer
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
from iree.compiler import tf as tfc
|
||||
from iree.compiler import compile_str
|
||||
from iree import runtime as ireert
|
||||
import os
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of inputs
|
||||
t5_inputs = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class AlbertModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(AlbertModule, self).__init__()
|
||||
self.m = TFAutoModelForMaskedLM.from_pretrained("albert-base-v2")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
|
||||
# text = "This is a great [MASK]."
|
||||
text = "This [MASK] is very tasty."
|
||||
encoded_inputs = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="tf",
|
||||
)
|
||||
inputs = (encoded_inputs["input_ids"], encoded_inputs["attention_mask"])
|
||||
mlir_importer = SharkImporter(
|
||||
AlbertModule(),
|
||||
inputs,
|
||||
frontend="tf",
|
||||
)
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=False
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir, mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
output_idx = 0
|
||||
data_idx = 1
|
||||
token_logits = shark_module.forward(inputs)[output_idx][data_idx]
|
||||
mask_id = np.where(
|
||||
tf.squeeze(encoded_inputs["input_ids"]) == tokenizer.mask_token_id
|
||||
)
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[0:5]
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> Sample/Warmup output: {text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence with [MASK] to fill: ")
|
||||
encoded_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
return_tensors="tf",
|
||||
)
|
||||
inputs = (
|
||||
encoded_inputs["input_ids"],
|
||||
encoded_inputs["attention_mask"],
|
||||
)
|
||||
token_logits = shark_module.forward(inputs)[output_idx][data_idx]
|
||||
mask_id = np.where(
|
||||
tf.squeeze(encoded_inputs["input_ids"])
|
||||
== tokenizer.mask_token_id
|
||||
)
|
||||
mask_token_logits = token_logits[0, mask_id, :]
|
||||
top_5_tokens = np.flip(np.argsort(mask_token_logits)).squeeze()[
|
||||
0:5
|
||||
]
|
||||
for token in top_5_tokens:
|
||||
print(
|
||||
f"'>>> {new_text.replace(tokenizer.mask_token, tokenizer.decode(token))}'"
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
sys.exit()
|
||||
14
shark/examples/shark_inference/bloom_tank.py
Normal file
14
shark/examples/shark_inference/bloom_tank.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"bloom", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device="cpu", mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
print("The golden result is:", golden_out)
|
||||
40
shark/examples/shark_inference/gpt2_tf.py
Normal file
40
shark/examples/shark_inference/gpt2_tf.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import GPT2Tokenizer, TFGPT2Model
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
gpt2_inputs = [
|
||||
tf.TensorSpec(shape=[1, 8], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[1, 8], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class GPT2Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(GPT2Module, self).__init__()
|
||||
self.m = TFGPT2Model.from_pretrained("distilgpt2")
|
||||
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, attention_mask=y)
|
||||
|
||||
@tf.function(input_signature=gpt2_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.m.predict(input_ids, attention_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
||||
text = "I love the distilled version of models."
|
||||
|
||||
inputs = tokenizer(text, return_tensors="tf")
|
||||
shark_module = SharkInference(
|
||||
GPT2Module(), (inputs["input_ids"], inputs["attention_mask"])
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
print(
|
||||
shark_module.forward((inputs["input_ids"], inputs["attention_mask"]))
|
||||
)
|
||||
18
shark/examples/shark_inference/llama/README.md
Normal file
18
shark/examples/shark_inference/llama/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# SHARK LLaMA
|
||||
|
||||
## TORCH-MLIR Version
|
||||
|
||||
```
|
||||
https://github.com/nod-ai/torch-mlir.git
|
||||
```
|
||||
Then check out the `complex` branch and `git submodule update --init` and then build with `.\build_tools\python_deploy\build_windows.ps1`
|
||||
|
||||
### Setup & Run
|
||||
```
|
||||
git clone https://github.com/nod-ai/llama.git
|
||||
```
|
||||
Then in this repository
|
||||
```
|
||||
pip install -e .
|
||||
python llama/shark_model.py
|
||||
```
|
||||
72
shark/examples/shark_inference/mega_test.py
Normal file
72
shark/examples/shark_inference/mega_test.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_compile import shark_compile_through_fx
|
||||
from MEGABYTE_pytorch import MEGABYTE
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class MegaModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MEGABYTE(
|
||||
num_tokens=16000, # number of tokens
|
||||
dim=(
|
||||
512,
|
||||
256,
|
||||
), # transformer model dimension (512 for coarsest, 256 for fine in this example)
|
||||
max_seq_len=(
|
||||
1024,
|
||||
4,
|
||||
), # sequence length for global and then local. this can be more than 2
|
||||
depth=(
|
||||
6,
|
||||
4,
|
||||
), # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
|
||||
dim_head=64, # dimension per head
|
||||
heads=8, # number of attention heads
|
||||
flash_attn=True, # use flash attention
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
|
||||
|
||||
megaModel = MegaModel()
|
||||
inputs = [torch.randint(0, 16000, (1, 1024, 4))]
|
||||
|
||||
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
|
||||
# 1. aten.alias
|
||||
shark_module, _ = shark_compile_through_fx(
|
||||
model=megaModel,
|
||||
inputs=inputs,
|
||||
extended_model_name="mega_shark",
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
save_dir=os.getcwd(),
|
||||
debug=False,
|
||||
generate_or_load_vmfb=True,
|
||||
extra_args=[],
|
||||
device="cuda",
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
# logits = model(x)
|
||||
|
||||
|
||||
def print_output_info(output, msg):
|
||||
print("\n", msg)
|
||||
print("\n\t", output.shape)
|
||||
|
||||
|
||||
ans = shark_module("forward", inputs)
|
||||
print_output_info(torch.from_numpy(ans), "SHARK's output")
|
||||
|
||||
ans = megaModel.forward(*inputs)
|
||||
print_output_info(ans, "ORIGINAL Model's output")
|
||||
|
||||
# and sample from the logits accordingly
|
||||
# or you can use the generate function
|
||||
|
||||
# NEED TO LOOK AT THIS LATER IF REQUIRED IN SHARK.
|
||||
# sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
|
||||
31
shark/examples/shark_inference/mhlo_example.py
Normal file
31
shark/examples/shark_inference/mhlo_example.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
import numpy as np
|
||||
|
||||
mhlo_ir = r"""builtin.module {
|
||||
func.func @forward(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> {
|
||||
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<4x4xf32>
|
||||
%1 = "mhlo.abs"(%0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %1 : tensor<4x4xf32>
|
||||
}
|
||||
}"""
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
|
||||
print("Running shark on cpu backend")
|
||||
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
|
||||
# Generate the random inputs and feed into the graph.
|
||||
x = shark_module.generate_random_inputs()
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on cuda backend")
|
||||
shark_module = SharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on vulkan backend")
|
||||
shark_module = SharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
35
shark/examples/shark_inference/minilm_benchmark.py
Normal file
35
shark/examples/shark_inference/minilm_benchmark.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
test_input = torch.randint(2, (1, 128))
|
||||
|
||||
shark_module = SharkInference(
|
||||
MiniLMSequenceClassification(),
|
||||
(test_input,),
|
||||
jit_trace=True,
|
||||
benchmark_mode=True,
|
||||
)
|
||||
|
||||
shark_module.compile()
|
||||
shark_module.forward((test_input,))
|
||||
shark_module.benchmark_all((test_input,))
|
||||
61
shark/examples/shark_inference/minilm_benchmark_tf.py
Normal file
61
shark/examples/shark_inference/minilm_benchmark_tf.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
self.m = TFBertModel.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True
|
||||
)
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
|
||||
test_input = (
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
BertModule(), test_input, benchmark_mode=True
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
shark_module.benchmark_all(test_input)
|
||||
73
shark/examples/shark_inference/minilm_jax.py
Normal file
73
shark/examples/shark_inference/minilm_jax.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from transformers import AutoTokenizer, FlaxAutoModel
|
||||
import torch
|
||||
import jax
|
||||
from typing import Union, Dict, List, Any
|
||||
import numpy as np
|
||||
from shark.shark_inference import SharkInference
|
||||
import io
|
||||
|
||||
NumpyTree = Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]]
|
||||
|
||||
|
||||
def convert_torch_tensor_tree_to_numpy(
|
||||
tree: Union[torch.tensor, Dict[str, torch.tensor], List[torch.tensor]]
|
||||
) -> NumpyTree:
|
||||
return jax.tree_util.tree_map(
|
||||
lambda torch_tensor: torch_tensor.cpu().detach().numpy(), tree
|
||||
)
|
||||
|
||||
|
||||
def convert_int64_to_int32(tree: NumpyTree) -> NumpyTree:
|
||||
return jax.tree_util.tree_map(
|
||||
lambda tensor: np.array(tensor, dtype=np.int32)
|
||||
if tensor.dtype == np.int64
|
||||
else tensor,
|
||||
tree,
|
||||
)
|
||||
|
||||
|
||||
def get_sample_input():
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
inputs_torch = tokenizer("Hello, World!", return_tensors="pt")
|
||||
return convert_int64_to_int32(
|
||||
convert_torch_tensor_tree_to_numpy(inputs_torch.data)
|
||||
)
|
||||
|
||||
|
||||
def get_jax_model():
|
||||
return FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
|
||||
|
||||
|
||||
def export_jax_to_mlir(jax_model: Any, sample_input: NumpyTree):
|
||||
model_mlir = jax.jit(jax_model).lower(**sample_input).compiler_ir()
|
||||
byte_stream = io.BytesIO()
|
||||
model_mlir.operation.write_bytecode(file=byte_stream)
|
||||
return byte_stream.getvalue()
|
||||
|
||||
|
||||
def assert_array_list_allclose(x, y, *args, **kwargs):
|
||||
assert len(x) == len(y)
|
||||
for a, b in zip(x, y):
|
||||
np.testing.assert_allclose(
|
||||
np.asarray(a), np.asarray(b), *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
sample_input = get_sample_input()
|
||||
jax_model = get_jax_model()
|
||||
mlir = export_jax_to_mlir(jax_model, sample_input)
|
||||
|
||||
# Compile and load module.
|
||||
shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
|
||||
shark_inference.compile()
|
||||
|
||||
# Run main function.
|
||||
result = shark_inference("main", jax.tree_util.tree_flatten(sample_input)[0])
|
||||
|
||||
# Run JAX model.
|
||||
reference_result = jax.tree_util.tree_flatten(jax_model(**sample_input))[0]
|
||||
|
||||
# Verify result.
|
||||
assert_array_list_allclose(result, reference_result, atol=1e-5)
|
||||
@@ -0,0 +1,6 @@
|
||||
flax
|
||||
jax[cpu]
|
||||
nodai-SHARK
|
||||
orbax
|
||||
transformers
|
||||
torch
|
||||
23
shark/examples/shark_inference/minilm_jit.py
Normal file
23
shark/examples/shark_inference/minilm_jit.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"microsoft/MiniLM-L12-H384-uncased",
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
|
||||
shark_module = SharkInference(mlir_model, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
print("The golden result is:", golden_out)
|
||||
|
||||
|
||||
# Let's generate random inputs, currently supported
|
||||
# for static models.
|
||||
rand_inputs = shark_module.generate_random_inputs()
|
||||
rand_results = shark_module.forward(rand_inputs)
|
||||
|
||||
print("Running shark_module with random_inputs is: ", rand_results)
|
||||
70
shark/examples/shark_inference/minilm_tf.py
Normal file
70
shark/examples/shark_inference/minilm_tf.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import tensorflow as tf
|
||||
from transformers import BertModel, BertTokenizer, TFBertModel
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
MAX_SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Create a set of 2-dimensional inputs
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, MAX_SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
# Create a BERT trainer with the created network.
|
||||
self.m = TFBertModel.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", from_pt=True
|
||||
)
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m.predict = lambda x, y, z: self.m.call(
|
||||
input_ids=x, attention_mask=y, token_type_ids=z, training=False
|
||||
)
|
||||
|
||||
@tf.function(input_signature=bert_input, jit_compile=True)
|
||||
def forward(self, input_ids, attention_mask, token_type_ids):
|
||||
return self.m.predict(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased"
|
||||
)
|
||||
text = "Replace me by any text you'd like."
|
||||
encoded_input = tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
)
|
||||
for key in encoded_input:
|
||||
encoded_input[key] = tf.expand_dims(
|
||||
tf.convert_to_tensor(encoded_input[key]), 0
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
BertModule(),
|
||||
(
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
|
||||
print(
|
||||
shark_module.forward(
|
||||
(
|
||||
encoded_input["input_ids"],
|
||||
encoded_input["attention_mask"],
|
||||
encoded_input["token_type_ids"],
|
||||
)
|
||||
)
|
||||
)
|
||||
1
shark/examples/shark_inference/minilm_tf_gpu_config.json
Normal file
1
shark/examples/shark_inference/minilm_tf_gpu_config.json
Normal file
File diff suppressed because one or more lines are too long
39
shark/examples/shark_inference/resnest.py
Normal file
39
shark/examples/shark_inference/resnest.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)
|
||||
|
||||
|
||||
class ResnestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = torch.hub.load(
|
||||
"zhanghang1989/ResNeSt", "resnest50", pretrained=True
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
ResnestModule(),
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
print(golden_out)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
print("Obtained result", result)
|
||||
74
shark/examples/shark_inference/resnet50_fp16.py
Normal file
74
shark/examples/shark_inference/resnet50_fp16.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.parser import shark_args
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import sys
|
||||
import torchvision.models as models
|
||||
import torch_mlir
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = models.resnet50(pretrained=True)
|
||||
self.train(False)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
model = VisionModule()
|
||||
test_input = torch.randn(1, 3, 224, 224)
|
||||
actual_out = model(test_input)
|
||||
|
||||
test_input_fp16 = test_input.to(device=torch.device("cuda"), dtype=torch.half)
|
||||
model_fp16 = model.half()
|
||||
model_fp16.eval()
|
||||
model_fp16.to("cuda")
|
||||
actual_out_fp16 = model_fp16(test_input_fp16)
|
||||
|
||||
ts_g = torch.jit.trace(model_fp16, [test_input_fp16])
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(test_input_fp16),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=True,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# from contextlib import redirect_stdout
|
||||
|
||||
# with open('resnet50_fp16_linalg_ir.mlir', 'w') as f:
|
||||
# with redirect_stdout(f):
|
||||
# print(module.operation.get_asm())
|
||||
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(mlir_model, device="cuda", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
def shark_result(x):
|
||||
x_ny = x.cpu().detach().numpy()
|
||||
inputs = (x_ny,)
|
||||
result = shark_module.forward(inputs)
|
||||
return torch.from_numpy(result)
|
||||
|
||||
|
||||
observed_out = shark_result(test_input_fp16)
|
||||
|
||||
print("Golden result:", actual_out_fp16)
|
||||
print("SHARK result:", observed_out)
|
||||
|
||||
actual_out_fp16 = actual_out_fp16.to(device=torch.device("cpu"))
|
||||
|
||||
print(
|
||||
torch.testing.assert_allclose(
|
||||
actual_out_fp16, observed_out, rtol=1e-2, atol=1e-2
|
||||
)
|
||||
)
|
||||
85
shark/examples/shark_inference/resnet50_script.py
Normal file
85
shark/examples/shark_inference/resnet50_script.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
################################## Preprocessing inputs and model ############
|
||||
def load_and_preprocess_image(url: str):
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36"
|
||||
}
|
||||
img = Image.open(
|
||||
requests.get(url, headers=headers, stream=True).raw
|
||||
).convert("RGB")
|
||||
# preprocessing pipeline
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
),
|
||||
]
|
||||
)
|
||||
img_preprocessed = preprocess(img)
|
||||
return torch.unsqueeze(img_preprocessed, 0)
|
||||
|
||||
|
||||
def load_labels():
|
||||
classes_text = requests.get(
|
||||
"https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt",
|
||||
stream=True,
|
||||
).text
|
||||
labels = [line.strip() for line in classes_text.splitlines()]
|
||||
return labels
|
||||
|
||||
|
||||
def top3_possibilities(res):
|
||||
_, indexes = torch.sort(res, descending=True)
|
||||
percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100
|
||||
top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]]
|
||||
return top3
|
||||
|
||||
|
||||
class Resnet50Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.resnet = models.resnet50(pretrained=True)
|
||||
self.train(False)
|
||||
|
||||
def forward(self, img):
|
||||
return self.resnet.forward(img)
|
||||
|
||||
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
|
||||
print("load image from " + image_url, file=sys.stderr)
|
||||
img = load_and_preprocess_image(image_url)
|
||||
labels = load_labels()
|
||||
|
||||
##############################################################################
|
||||
|
||||
|
||||
## Can pass any img or input to the forward module.
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"resnet50", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(mlir_model, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
path = shark_module.save_module()
|
||||
shark_module.load_module(path)
|
||||
result = shark_module("forward", (img.detach().numpy(),))
|
||||
|
||||
print("The top 3 results obtained via shark_runner is:")
|
||||
print(top3_possibilities(torch.from_numpy(result)))
|
||||
|
||||
print()
|
||||
|
||||
print("The top 3 results obtained via torch is:")
|
||||
print(top3_possibilities(Resnet50Module()(img)))
|
||||
842
shark/examples/shark_inference/sharded_bloom.py
Normal file
842
shark/examples/shark_inference/sharded_bloom.py
Normal file
@@ -0,0 +1,842 @@
|
||||
####################################################################################
|
||||
# Please make sure you have transformers 4.21.2 installed before running this demo
|
||||
#
|
||||
# -p --model_path: the directory in which you want to store the bloom files.
|
||||
# -dl --device_list: the list of device indices you want to use. if you want to only use the first device, or you are running on cpu leave this blank.
|
||||
# Otherwise, please give this argument in this format: "[0, 1, 2]"
|
||||
# -de --device: the device you want to run bloom on. E.G. cpu, cuda
|
||||
# -c, --recompile: set to true if you want to recompile to vmfb.
|
||||
# -d, --download: set to true if you want to redownload the mlir files
|
||||
# -cm, --create_mlirs: set to true if you want to create the mlir files from scratch. please make sure you have transformers 4.21.2 before using this option
|
||||
# -t --token_count: the number of tokens you want to generate
|
||||
# -pr --prompt: the prompt you want to feed to the model
|
||||
# -m --model_name: the name of the model, e.g. bloom-560m
|
||||
#
|
||||
# If you don't specify a prompt when you run this example, you will be able to give prompts through the terminal. Run the
|
||||
# example in this way if you want to run multiple examples without reinitializing the model
|
||||
#####################################################################################
|
||||
|
||||
import os
|
||||
import io
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
import torch_mlir
|
||||
from torch_mlir import TensorPlaceholder
|
||||
import re
|
||||
from transformers.models.bloom.configuration_bloom import BloomConfig
|
||||
import json
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
import urllib.request
|
||||
import subprocess
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from transformers import (
|
||||
BloomTokenizerFast,
|
||||
BloomForSequenceClassification,
|
||||
BloomForCausalLM,
|
||||
)
|
||||
from transformers.models.bloom.modeling_bloom import (
|
||||
BloomBlock,
|
||||
build_alibi_tensor,
|
||||
)
|
||||
|
||||
IS_CUDA = False
|
||||
|
||||
|
||||
class ShardedBloom:
|
||||
def __init__(self, src_folder):
|
||||
f = open(f"{src_folder}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
|
||||
self.layers_initialized = False
|
||||
|
||||
self.src_folder = src_folder
|
||||
try:
|
||||
self.n_embed = config["n_embed"]
|
||||
except KeyError:
|
||||
self.n_embed = config["hidden_size"]
|
||||
self.vocab_size = config["vocab_size"]
|
||||
self.n_layer = config["n_layer"]
|
||||
try:
|
||||
self.n_head = config["num_attention_heads"]
|
||||
except KeyError:
|
||||
self.n_head = config["n_head"]
|
||||
|
||||
def _init_layer(self, layer_name, device, replace, device_idx):
|
||||
if replace or not os.path.exists(
|
||||
f"{self.src_folder}/{layer_name}.vmfb"
|
||||
):
|
||||
f_ = open(f"{self.src_folder}/{layer_name}.mlir", encoding="utf-8")
|
||||
module = f_.read()
|
||||
f_.close()
|
||||
module = bytes(module, "utf-8")
|
||||
shark_module = SharkInference(
|
||||
module,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
shark_module.save_module(
|
||||
module_name=f"{self.src_folder}/{layer_name}",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
"",
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
return shark_module
|
||||
|
||||
def init_layers(self, device, replace=False, device_idx=[0]):
|
||||
if device_idx is not None:
|
||||
n_devices = len(device_idx)
|
||||
|
||||
self.word_embeddings_module = self._init_layer(
|
||||
"word_embeddings",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[0 % n_devices],
|
||||
)
|
||||
self.word_embeddings_layernorm_module = self._init_layer(
|
||||
"word_embeddings_layernorm",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[1 % n_devices],
|
||||
)
|
||||
self.ln_f_module = self._init_layer(
|
||||
"ln_f",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[2 % n_devices],
|
||||
)
|
||||
self.lm_head_module = self._init_layer(
|
||||
"lm_head",
|
||||
device,
|
||||
replace,
|
||||
device_idx if device_idx is None else device_idx[3 % n_devices],
|
||||
)
|
||||
self.block_modules = [
|
||||
self._init_layer(
|
||||
f"bloom_block_{i}",
|
||||
device,
|
||||
replace,
|
||||
device_idx
|
||||
if device_idx is None
|
||||
else device_idx[(i + 4) % n_devices],
|
||||
)
|
||||
for i in range(self.n_layer)
|
||||
]
|
||||
|
||||
self.layers_initialized = True
|
||||
|
||||
def load_layers(self):
|
||||
assert self.layers_initialized
|
||||
|
||||
self.word_embeddings_module.load_module(
|
||||
f"{self.src_folder}/word_embeddings.vmfb"
|
||||
)
|
||||
self.word_embeddings_layernorm_module.load_module(
|
||||
f"{self.src_folder}/word_embeddings_layernorm.vmfb"
|
||||
)
|
||||
for block_module, i in zip(self.block_modules, range(self.n_layer)):
|
||||
block_module.load_module(f"{self.src_folder}/bloom_block_{i}.vmfb")
|
||||
self.ln_f_module.load_module(f"{self.src_folder}/ln_f.vmfb")
|
||||
self.lm_head_module.load_module(f"{self.src_folder}/lm_head.vmfb")
|
||||
|
||||
def forward_pass(self, input_ids, device):
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.word_embeddings_module.device_idx)
|
||||
|
||||
input_embeds = self.word_embeddings_module(
|
||||
inputs=(input_ids,), function_name="forward"
|
||||
)
|
||||
|
||||
input_embeds = torch.tensor(input_embeds).float()
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.word_embeddings_layernorm_module.device_idx)
|
||||
hidden_states = self.word_embeddings_layernorm_module(
|
||||
inputs=(input_embeds,), function_name="forward"
|
||||
)
|
||||
|
||||
hidden_states = torch.tensor(hidden_states).float()
|
||||
|
||||
attention_mask = torch.ones(
|
||||
[hidden_states.shape[0], len(input_ids[0])]
|
||||
)
|
||||
alibi = build_alibi_tensor(
|
||||
attention_mask,
|
||||
self.n_head,
|
||||
hidden_states.dtype,
|
||||
hidden_states.device,
|
||||
)
|
||||
|
||||
causal_mask = _prepare_attn_mask(
|
||||
attention_mask, input_ids.size(), input_embeds, 0
|
||||
)
|
||||
causal_mask = torch.tensor(causal_mask).float()
|
||||
|
||||
presents = ()
|
||||
all_hidden_states = tuple(hidden_states)
|
||||
|
||||
for block_module, i in zip(self.block_modules, range(self.n_layer)):
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(block_module.device_idx)
|
||||
|
||||
output = block_module(
|
||||
inputs=(
|
||||
hidden_states.detach().numpy(),
|
||||
alibi.detach().numpy(),
|
||||
causal_mask.detach().numpy(),
|
||||
),
|
||||
function_name="forward",
|
||||
)
|
||||
hidden_states = torch.tensor(output[0]).float()
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
presents = presents + (
|
||||
tuple(
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
)
|
||||
),
|
||||
)
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.ln_f_module.device_idx)
|
||||
|
||||
hidden_states = self.ln_f_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
if IS_CUDA:
|
||||
cudaSetDevice(self.lm_head_module.device_idx)
|
||||
|
||||
logits = self.lm_head_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
logits = torch.tensor(logits).float()
|
||||
|
||||
return torch.argmax(logits[:, -1, :], dim=-1)
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
past_key_values_length: int = 0,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
|
||||
mask.masked_fill_(intermediate_mask, 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
target_length, past_key_values_length, dtype=dtype
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
expanded_mask = mask[None, None, :, :].expand(
|
||||
batch_size, 1, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
batch_size, source_length = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else source_length
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :]
|
||||
.expand(batch_size, 1, tgt_len, source_length)
|
||||
.to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _prepare_attn_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
past_key_values_length=past_key_values_length,
|
||||
).to(attention_mask.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
|
||||
def download_model(destination_folder, model_name):
|
||||
download_public_file(
|
||||
f"gs://shark_tank/sharded_bloom/{model_name}/", destination_folder
|
||||
)
|
||||
|
||||
|
||||
def compile_embeddings(embeddings_layer, input_ids, path):
|
||||
input_ids_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
input_ids, dynamic_axes=[1]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
embeddings_layer,
|
||||
(input_ids_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def compile_word_embeddings_layernorm(
|
||||
embeddings_layer_layernorm, embeds, path
|
||||
):
|
||||
embeds_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
embeds, dynamic_axes=[1]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
embeddings_layer_layernorm,
|
||||
(embeds_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def strip_overloads(gm):
|
||||
"""
|
||||
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
|
||||
Args:
|
||||
gm(fx.GraphModule): The input Fx graph module to be modified
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if isinstance(node.target, torch._ops.OpOverload):
|
||||
node.target = node.target.overloadpacket
|
||||
gm.recompile()
|
||||
|
||||
|
||||
def compile_to_mlir(
|
||||
bblock,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=None,
|
||||
output_attentions=False,
|
||||
alibi=None,
|
||||
block_index=0,
|
||||
path=".",
|
||||
):
|
||||
fx_g = make_fx(
|
||||
bblock,
|
||||
decomposition_table=get_decompositions(
|
||||
[
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
),
|
||||
tracing_mode="real",
|
||||
_allow_non_fake_inputs=False,
|
||||
)(hidden_states, alibi, attention_mask)
|
||||
|
||||
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
fx_g.recompile()
|
||||
|
||||
strip_overloads(fx_g)
|
||||
|
||||
hidden_states_placeholder = TensorPlaceholder.like(
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
attention_mask_placeholder = TensorPlaceholder.like(
|
||||
attention_mask, dynamic_axes=[2, 3]
|
||||
)
|
||||
alibi_placeholder = TensorPlaceholder.like(alibi, dynamic_axes=[2])
|
||||
|
||||
ts_g = torch.jit.script(fx_g)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_g,
|
||||
(
|
||||
hidden_states_placeholder,
|
||||
alibi_placeholder,
|
||||
attention_mask_placeholder,
|
||||
),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
module_placeholder = module
|
||||
module_context = module_placeholder.context
|
||||
|
||||
def check_valid_line(line, line_n, mlir_file_len):
|
||||
if "private" in line:
|
||||
return False
|
||||
if "attributes" in line:
|
||||
return False
|
||||
if mlir_file_len - line_n == 2:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
mlir_file_len = len(str(module).split("\n"))
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "17x" in line:
|
||||
line = re.sub("17x", "?x", line)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
|
||||
)
|
||||
if "arith.cmpi eq" in line:
|
||||
line = re.sub("c17", "dim", line)
|
||||
if " 17," in line:
|
||||
line = re.sub(" 17,", " %dim,", line)
|
||||
return line
|
||||
|
||||
module = "\n".join(
|
||||
[
|
||||
remove_constant_dim(line)
|
||||
for line, line_n in zip(
|
||||
str(module).split("\n"), range(mlir_file_len)
|
||||
)
|
||||
if check_valid_line(line, line_n, mlir_file_len)
|
||||
]
|
||||
)
|
||||
|
||||
module = module_placeholder.parse(module, context=module_context)
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def compile_ln_f(ln_f, hidden_layers, path):
|
||||
hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_layers, dynamic_axes=[1]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ln_f,
|
||||
(hidden_layers_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def compile_lm_head(lm_head, hidden_layers, path):
|
||||
hidden_layers_placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
hidden_layers, dynamic_axes=[1]
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
lm_head,
|
||||
(hidden_layers_placeholder),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
f_ = open(path, "w+")
|
||||
f_.write(str(module))
|
||||
f_.close()
|
||||
return
|
||||
|
||||
|
||||
def create_mlirs(destination_folder, model_name):
|
||||
model_config = "bigscience/" + model_name
|
||||
sample_input_ids = torch.ones([1, 17], dtype=torch.int64)
|
||||
|
||||
urllib.request.urlretrieve(
|
||||
f"https://huggingface.co/bigscience/{model_name}/resolve/main/config.json",
|
||||
filename=f"{destination_folder}/config.json",
|
||||
)
|
||||
urllib.request.urlretrieve(
|
||||
f"https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json",
|
||||
filename=f"{destination_folder}/tokenizer.json",
|
||||
)
|
||||
|
||||
class HuggingFaceLanguage(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = BloomForCausalLM.from_pretrained(model_config)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
class HuggingFaceBlock(torch.nn.Module):
|
||||
def __init__(self, block):
|
||||
super().__init__()
|
||||
self.model = block
|
||||
|
||||
def forward(self, tokens, alibi, attention_mask):
|
||||
output = self.model(
|
||||
hidden_states=tokens,
|
||||
alibi=alibi,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
output_attentions=False,
|
||||
)
|
||||
return (output[0], output[1][0], output[1][1])
|
||||
|
||||
model = HuggingFaceLanguage()
|
||||
|
||||
compile_embeddings(
|
||||
model.model.transformer.word_embeddings,
|
||||
sample_input_ids,
|
||||
f"{destination_folder}/word_embeddings.mlir",
|
||||
)
|
||||
|
||||
inputs_embeds = model.model.transformer.word_embeddings(sample_input_ids)
|
||||
|
||||
compile_word_embeddings_layernorm(
|
||||
model.model.transformer.word_embeddings_layernorm,
|
||||
inputs_embeds,
|
||||
f"{destination_folder}/word_embeddings_layernorm.mlir",
|
||||
)
|
||||
|
||||
hidden_states = model.model.transformer.word_embeddings_layernorm(
|
||||
inputs_embeds
|
||||
)
|
||||
|
||||
input_shape = sample_input_ids.size()
|
||||
|
||||
current_sequence_length = hidden_states.shape[1]
|
||||
past_key_values_length = 0
|
||||
past_key_values = tuple([None] * len(model.model.transformer.h))
|
||||
|
||||
attention_mask = torch.ones(
|
||||
(hidden_states.shape[0], current_sequence_length), device="cpu"
|
||||
)
|
||||
|
||||
alibi = build_alibi_tensor(
|
||||
attention_mask,
|
||||
model.model.transformer.n_head,
|
||||
hidden_states.dtype,
|
||||
"cpu",
|
||||
)
|
||||
|
||||
causal_mask = _prepare_attn_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
head_mask = model.model.transformer.get_head_mask(
|
||||
None, model.model.transformer.config.n_layer
|
||||
)
|
||||
output_attentions = model.model.transformer.config.output_attentions
|
||||
|
||||
all_hidden_states = ()
|
||||
|
||||
for i, (block, layer_past) in enumerate(
|
||||
zip(model.model.transformer.h, past_key_values)
|
||||
):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
proxy_model = HuggingFaceBlock(block)
|
||||
|
||||
compile_to_mlir(
|
||||
proxy_model,
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=causal_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=True,
|
||||
output_attentions=output_attentions,
|
||||
alibi=alibi,
|
||||
block_index=i,
|
||||
path=f"{destination_folder}/bloom_block_{i}.mlir",
|
||||
)
|
||||
|
||||
compile_ln_f(
|
||||
model.model.transformer.ln_f,
|
||||
hidden_states,
|
||||
f"{destination_folder}/ln_f.mlir",
|
||||
)
|
||||
hidden_states = model.model.transformer.ln_f(hidden_states)
|
||||
compile_lm_head(
|
||||
model.model.lm_head,
|
||||
hidden_states,
|
||||
f"{destination_folder}/lm_head.mlir",
|
||||
)
|
||||
|
||||
|
||||
def run_large_model(
|
||||
token_count,
|
||||
recompile,
|
||||
model_path,
|
||||
prompt,
|
||||
device_list,
|
||||
script_path,
|
||||
device,
|
||||
):
|
||||
f = open(f"{model_path}/prompt.txt", "w+")
|
||||
f.write(prompt)
|
||||
f.close()
|
||||
for i in range(token_count):
|
||||
if i == 0:
|
||||
will_compile = recompile
|
||||
else:
|
||||
will_compile = False
|
||||
f = open(f"{model_path}/prompt.txt", "r")
|
||||
prompt = f.read()
|
||||
f.close()
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"python",
|
||||
script_path,
|
||||
model_path,
|
||||
"start",
|
||||
str(will_compile),
|
||||
"cpu",
|
||||
"None",
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
for i in range(config["n_layer"]):
|
||||
if device_list is not None:
|
||||
device_idx = str(device_list[i % len(device_list)])
|
||||
else:
|
||||
device_idx = "None"
|
||||
subprocess.run(
|
||||
[
|
||||
"python",
|
||||
script_path,
|
||||
model_path,
|
||||
str(i),
|
||||
str(will_compile),
|
||||
device,
|
||||
device_idx,
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
subprocess.run(
|
||||
[
|
||||
"python",
|
||||
script_path,
|
||||
model_path,
|
||||
"end",
|
||||
str(will_compile),
|
||||
"cpu",
|
||||
"None",
|
||||
prompt,
|
||||
]
|
||||
)
|
||||
|
||||
f = open(f"{model_path}/prompt.txt", "r")
|
||||
output = f.read()
|
||||
f.close()
|
||||
print(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(prog="Bloom-560m")
|
||||
parser.add_argument("-p", "--model_path")
|
||||
parser.add_argument("-dl", "--device_list", default=None)
|
||||
parser.add_argument("-de", "--device", default="cpu")
|
||||
parser.add_argument("-c", "--recompile", default=False, type=bool)
|
||||
parser.add_argument("-d", "--download", default=False, type=bool)
|
||||
parser.add_argument("-t", "--token_count", default=10, type=int)
|
||||
parser.add_argument("-m", "--model_name", default="bloom-560m")
|
||||
parser.add_argument("-cm", "--create_mlirs", default=False, type=bool)
|
||||
|
||||
parser.add_argument(
|
||||
"-lm", "--large_model_memory_efficient", default=False, type=bool
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-pr",
|
||||
"--prompt",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.create_mlirs and args.large_model_memory_efficient:
|
||||
print(
|
||||
"Warning: If you need to use memory efficient mode, you probably want to use 'download' instead"
|
||||
)
|
||||
|
||||
if not os.path.isdir(args.model_path):
|
||||
os.mkdir(args.model_path)
|
||||
|
||||
if args.device_list is not None:
|
||||
args.device_list = json.loads(args.device_list)
|
||||
|
||||
if args.device == "cuda" and args.device_list is not None:
|
||||
IS_CUDA = True
|
||||
from cuda.cudart import cudaSetDevice
|
||||
if args.download and args.create_mlirs:
|
||||
print(
|
||||
"WARNING: It is not advised to turn on both download and create_mlirs"
|
||||
)
|
||||
if args.download:
|
||||
download_model(args.model_path, args.model_name)
|
||||
if args.create_mlirs:
|
||||
create_mlirs(args.model_path, args.model_name)
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
if args.prompt is not None:
|
||||
input_ids = tokenizer.encode(args.prompt, return_tensors="pt")
|
||||
|
||||
if args.large_model_memory_efficient:
|
||||
f = open(f"{args.model_path}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
|
||||
self_path = os.path.dirname(os.path.abspath(__file__))
|
||||
script_path = os.path.join(self_path, "sharded_bloom_large_models.py")
|
||||
|
||||
if args.prompt is not None:
|
||||
run_large_model(
|
||||
args.token_count,
|
||||
args.recompile,
|
||||
args.model_path,
|
||||
args.prompt,
|
||||
args.device_list,
|
||||
script_path,
|
||||
args.device,
|
||||
)
|
||||
|
||||
else:
|
||||
while True:
|
||||
prompt = input("Enter Prompt: ")
|
||||
try:
|
||||
token_count = int(
|
||||
input("Enter number of tokens you want to generate: ")
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"Invalid integer entered. Using default value of 10"
|
||||
)
|
||||
token_count = 10
|
||||
|
||||
run_large_model(
|
||||
token_count,
|
||||
args.recompile,
|
||||
args.model_path,
|
||||
prompt,
|
||||
args.device_list,
|
||||
script_path,
|
||||
args.device,
|
||||
)
|
||||
|
||||
else:
|
||||
shardedbloom = ShardedBloom(args.model_path)
|
||||
shardedbloom.init_layers(
|
||||
device=args.device,
|
||||
replace=args.recompile,
|
||||
device_idx=args.device_list,
|
||||
)
|
||||
shardedbloom.load_layers()
|
||||
|
||||
if args.prompt is not None:
|
||||
for _ in range(args.token_count):
|
||||
next_token = shardedbloom.forward_pass(
|
||||
torch.tensor(input_ids), device=args.device
|
||||
)
|
||||
input_ids = torch.cat(
|
||||
[input_ids, next_token.unsqueeze(-1)], dim=-1
|
||||
)
|
||||
|
||||
print(tokenizer.decode(input_ids.squeeze()))
|
||||
|
||||
else:
|
||||
while True:
|
||||
prompt = input("Enter Prompt: ")
|
||||
try:
|
||||
token_count = int(
|
||||
input("Enter number of tokens you want to generate: ")
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"Invalid integer entered. Using default value of 10"
|
||||
)
|
||||
token_count = 10
|
||||
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
|
||||
for _ in range(token_count):
|
||||
next_token = shardedbloom.forward_pass(
|
||||
torch.tensor(input_ids), device=args.device
|
||||
)
|
||||
input_ids = torch.cat(
|
||||
[input_ids, next_token.unsqueeze(-1)], dim=-1
|
||||
)
|
||||
|
||||
print(tokenizer.decode(input_ids.squeeze()))
|
||||
381
shark/examples/shark_inference/sharded_bloom_large_models.py
Normal file
381
shark/examples/shark_inference/sharded_bloom_large_models.py
Normal file
@@ -0,0 +1,381 @@
|
||||
import sys
|
||||
import os
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomConfig
|
||||
import re
|
||||
from shark.shark_inference import SharkInference
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
from transformers.models.bloom.modeling_bloom import (
|
||||
BloomBlock,
|
||||
build_alibi_tensor,
|
||||
)
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
batch_size, source_length = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else source_length
|
||||
|
||||
expanded_mask = (
|
||||
mask[:, None, None, :]
|
||||
.expand(batch_size, 1, tgt_len, source_length)
|
||||
.to(dtype)
|
||||
)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _prepare_attn_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
past_key_values_length=past_key_values_length,
|
||||
).to(attention_mask.device)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
past_key_values_length: int = 0,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
|
||||
mask_cond = torch.arange(mask.size(-1))
|
||||
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
|
||||
mask.masked_fill_(intermediate_mask, 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
target_length, past_key_values_length, dtype=dtype
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
expanded_mask = mask[None, None, :, :].expand(
|
||||
batch_size, 1, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
working_dir = sys.argv[1]
|
||||
layer_name = sys.argv[2]
|
||||
will_compile = sys.argv[3]
|
||||
device = sys.argv[4]
|
||||
device_idx = sys.argv[5]
|
||||
prompt = sys.argv[6]
|
||||
|
||||
if device_idx.lower().strip() == "none":
|
||||
device_idx = None
|
||||
else:
|
||||
device_idx = int(device_idx)
|
||||
|
||||
if will_compile.lower().strip() == "true":
|
||||
will_compile = True
|
||||
else:
|
||||
will_compile = False
|
||||
|
||||
f = open(f"{working_dir}/config.json")
|
||||
config = json.load(f)
|
||||
f.close()
|
||||
|
||||
layers_initialized = False
|
||||
try:
|
||||
n_embed = config["n_embed"]
|
||||
except KeyError:
|
||||
n_embed = config["hidden_size"]
|
||||
vocab_size = config["vocab_size"]
|
||||
n_layer = config["n_layer"]
|
||||
try:
|
||||
n_head = config["num_attention_heads"]
|
||||
except KeyError:
|
||||
n_head = config["n_head"]
|
||||
|
||||
if not os.path.isdir(working_dir):
|
||||
os.mkdir(working_dir)
|
||||
|
||||
if layer_name == "start":
|
||||
tokenizer = AutoTokenizer.from_pretrained(working_dir)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
||||
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(f"{working_dir}/word_embeddings.mlir", encoding="utf-8")
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=None,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/word_embeddings",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/word_embeddings.vmfb")
|
||||
input_embeds = shark_module(
|
||||
inputs=(input_ids,), function_name="forward"
|
||||
)
|
||||
input_embeds = torch.tensor(input_embeds).float()
|
||||
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(
|
||||
f"{working_dir}/word_embeddings_layernorm.mlir",
|
||||
encoding="utf-8",
|
||||
)
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=None,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/word_embeddings_layernorm",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(
|
||||
f"{working_dir}/word_embeddings_layernorm.vmfb"
|
||||
)
|
||||
hidden_states = shark_module(
|
||||
inputs=(input_embeds,), function_name="forward"
|
||||
)
|
||||
hidden_states = torch.tensor(hidden_states).float()
|
||||
|
||||
torch.save(hidden_states, f"{working_dir}/hidden_states_0.pt")
|
||||
|
||||
attention_mask = torch.ones(
|
||||
[hidden_states.shape[0], len(input_ids[0])]
|
||||
)
|
||||
|
||||
attention_mask = torch.tensor(attention_mask).float()
|
||||
|
||||
alibi = build_alibi_tensor(
|
||||
attention_mask,
|
||||
n_head,
|
||||
hidden_states.dtype,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
torch.save(alibi, f"{working_dir}/alibi.pt")
|
||||
|
||||
causal_mask = _prepare_attn_mask(
|
||||
attention_mask, input_ids.size(), input_embeds, 0
|
||||
)
|
||||
causal_mask = torch.tensor(causal_mask).float()
|
||||
|
||||
torch.save(causal_mask, f"{working_dir}/causal_mask.pt")
|
||||
|
||||
elif layer_name in [str(x) for x in range(n_layer)]:
|
||||
hidden_states = torch.load(
|
||||
f"{working_dir}/hidden_states_{layer_name}.pt"
|
||||
)
|
||||
alibi = torch.load(f"{working_dir}/alibi.pt")
|
||||
causal_mask = torch.load(f"{working_dir}/causal_mask.pt")
|
||||
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(
|
||||
f"{working_dir}/bloom_block_{layer_name}.mlir",
|
||||
encoding="utf-8",
|
||||
)
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/bloom_block_{layer_name}",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(
|
||||
f"{working_dir}/bloom_block_{layer_name}.vmfb"
|
||||
)
|
||||
|
||||
output = shark_module(
|
||||
inputs=(
|
||||
hidden_states.detach().numpy(),
|
||||
alibi.detach().numpy(),
|
||||
causal_mask.detach().numpy(),
|
||||
),
|
||||
function_name="forward",
|
||||
)
|
||||
|
||||
hidden_states = torch.tensor(output[0]).float()
|
||||
|
||||
torch.save(
|
||||
hidden_states,
|
||||
f"{working_dir}/hidden_states_{int(layer_name) + 1}.pt",
|
||||
)
|
||||
|
||||
elif layer_name == "end":
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(f"{working_dir}/ln_f.mlir", encoding="utf-8")
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=None,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/ln_f",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/ln_f.vmfb")
|
||||
|
||||
hidden_states = torch.load(f"{working_dir}/hidden_states_{n_layer}.pt")
|
||||
|
||||
hidden_states = shark_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
|
||||
mlir_str = ""
|
||||
|
||||
if will_compile:
|
||||
f = open(f"{working_dir}/lm_head.mlir", encoding="utf-8")
|
||||
mlir_str = f.read()
|
||||
f.close()
|
||||
|
||||
mlir_str = bytes(mlir_str, "utf-8")
|
||||
|
||||
if config["n_embed"] == 14336:
|
||||
|
||||
def get_state_dict():
|
||||
d = torch.load(
|
||||
f"{working_dir}/pytorch_model_00001-of-00072.bin"
|
||||
)
|
||||
return OrderedDict(
|
||||
(k.replace("word_embeddings.", ""), v)
|
||||
for k, v in d.items()
|
||||
)
|
||||
|
||||
def load_causal_lm_head():
|
||||
linear = nn.utils.skip_init(
|
||||
nn.Linear, 14336, 250880, bias=False, dtype=torch.float
|
||||
)
|
||||
linear.load_state_dict(get_state_dict(), strict=False)
|
||||
return linear.float()
|
||||
|
||||
lm_head = load_causal_lm_head()
|
||||
|
||||
logits = lm_head(torch.tensor(hidden_states).float())
|
||||
|
||||
else:
|
||||
shark_module = SharkInference(
|
||||
mlir_str,
|
||||
device="cpu",
|
||||
mlir_dialect="tm_tensor",
|
||||
device_idx=None,
|
||||
)
|
||||
|
||||
if will_compile:
|
||||
shark_module.save_module(
|
||||
module_name=f"{working_dir}/lm_head",
|
||||
extra_args=[
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-stream-resource-max-allocation-size=1000000000",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
],
|
||||
)
|
||||
|
||||
shark_module.load_module(f"{working_dir}/lm_head.vmfb")
|
||||
|
||||
logits = shark_module(
|
||||
inputs=(hidden_states,), function_name="forward"
|
||||
)
|
||||
|
||||
logits = torch.tensor(logits).float()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(working_dir)
|
||||
|
||||
next_token = tokenizer.decode(torch.argmax(logits[:, -1, :], dim=-1))
|
||||
|
||||
f = open(f"{working_dir}/prompt.txt", "w+")
|
||||
f.write(prompt + next_token)
|
||||
f.close()
|
||||
390
shark/examples/shark_inference/simple_dlrm.py
Normal file
390
shark/examples/shark_inference/simple_dlrm.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# Description: an implementation of a deep learning recommendation model (DLRM)
|
||||
# The model input consists of dense and sparse features. The former is a vector
|
||||
# of floating point values. The latter is a list of sparse indices into
|
||||
# embedding tables, which consist of vectors of floating point values.
|
||||
# The selected vectors are passed to mlp networks denoted by triangles,
|
||||
# in some cases the vectors are interacted through operators (Ops).
|
||||
#
|
||||
# output:
|
||||
# vector of values
|
||||
# model: |
|
||||
# /\
|
||||
# /__\
|
||||
# |
|
||||
# _____________________> Op <___________________
|
||||
# / | \
|
||||
# /\ /\ /\
|
||||
# /__\ /__\ ... /__\
|
||||
# | | |
|
||||
# | Op Op
|
||||
# | ____/__\_____ ____/__\____
|
||||
# | |_Emb_|____|__| ... |_Emb_|__|___|
|
||||
# input:
|
||||
# [ dense features ] [sparse indices] , ..., [sparse indices]
|
||||
#
|
||||
# More precise definition of model layers:
|
||||
# 1) fully connected layers of an mlp
|
||||
# z = f(y)
|
||||
# y = Wx + b
|
||||
#
|
||||
# 2) embedding lookup (for a list of sparse indices p=[p1,...,pk])
|
||||
# z = Op(e1,...,ek)
|
||||
# obtain vectors e1=E[:,p1], ..., ek=E[:,pk]
|
||||
#
|
||||
# 3) Operator Op can be one of the following
|
||||
# Sum(e1,...,ek) = e1 + ... + ek
|
||||
# Dot(e1,...,ek) = [e1'e1, ..., e1'ek, ..., ek'e1, ..., ek'ek]
|
||||
# Cat(e1,...,ek) = [e1', ..., ek']'
|
||||
# where ' denotes transpose operation
|
||||
#
|
||||
# References:
|
||||
# [1] Maxim Naumov, Dheevatsa Mudigere, Hao-Jun Michael Shi, Jianyu Huang,
|
||||
# Narayanan Sundaram, Jongsoo Park, Xiaodong Wang, Udit Gupta, Carole-Jean Wu,
|
||||
# Alisson G. Azzolini, Dmytro Dzhulgakov, Andrey Mallevich, Ilia Cherniavskii,
|
||||
# Yinghai Lu, Raghuraman Krishnamoorthi, Ansha Yu, Volodymyr Kondratenko,
|
||||
# Stephanie Pereira, Xianjie Chen, Wenlin Chen, Vijay Rao, Bill Jia, Liang Xiong,
|
||||
# Misha Smelyanskiy, "Deep Learning Recommendation Model for Personalization and
|
||||
# Recommendation Systems", CoRR, arXiv:1906.00091, 2019
|
||||
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
### define dlrm in PyTorch ###
|
||||
class DLRM_Net(nn.Module):
|
||||
def create_mlp(self, ln, sigmoid_layer):
|
||||
# build MLP layer by layer
|
||||
layers = nn.ModuleList()
|
||||
for i in range(0, ln.size - 1):
|
||||
n = ln[i]
|
||||
m = ln[i + 1]
|
||||
|
||||
# construct fully connected operator
|
||||
LL = nn.Linear(int(n), int(m), bias=True)
|
||||
|
||||
# initialize the weights
|
||||
# with torch.no_grad():
|
||||
# custom Xavier input, output or two-sided fill
|
||||
|
||||
mean = 0.0 # std_dev = np.sqrt(variance)
|
||||
std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
|
||||
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
|
||||
std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
|
||||
bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
|
||||
LL.weight.data = torch.tensor(W, requires_grad=True)
|
||||
LL.bias.data = torch.tensor(bt, requires_grad=True)
|
||||
|
||||
# approach 2
|
||||
# LL.weight.data.copy_(torch.tensor(W))
|
||||
# LL.bias.data.copy_(torch.tensor(bt))
|
||||
# approach 3
|
||||
# LL.weight = Parameter(torch.tensor(W),requires_grad=True)
|
||||
# LL.bias = Parameter(torch.tensor(bt),requires_grad=True)
|
||||
layers.append(LL)
|
||||
|
||||
# construct sigmoid or relu operator
|
||||
if i == sigmoid_layer:
|
||||
layers.append(nn.Sigmoid())
|
||||
else:
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
# approach 1: use ModuleList
|
||||
# return layers
|
||||
# approach 2: use Sequential container to wrap all layers
|
||||
return torch.nn.Sequential(*layers)
|
||||
|
||||
def create_emb(self, m, ln, weighted_pooling=None):
|
||||
emb_l = nn.ModuleList()
|
||||
v_W_l = []
|
||||
for i in range(0, ln.size):
|
||||
n = ln[i]
|
||||
|
||||
# construct embedding operator
|
||||
EE = nn.EmbeddingBag(n, m, mode="sum")
|
||||
# initialize embeddings
|
||||
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
|
||||
W = np.random.uniform(
|
||||
low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m)
|
||||
).astype(np.float32)
|
||||
# approach 1
|
||||
print(W)
|
||||
EE.weight.data = torch.tensor(W, requires_grad=True)
|
||||
# approach 2
|
||||
# EE.weight.data.copy_(torch.tensor(W))
|
||||
# approach 3
|
||||
# EE.weight = Parameter(torch.tensor(W),requires_grad=True)
|
||||
if weighted_pooling is None:
|
||||
v_W_l.append(None)
|
||||
else:
|
||||
v_W_l.append(torch.ones(n, dtype=torch.float32))
|
||||
emb_l.append(EE)
|
||||
return emb_l, v_W_l
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
m_spa=None,
|
||||
ln_emb=None,
|
||||
ln_bot=None,
|
||||
ln_top=None,
|
||||
arch_interaction_op=None,
|
||||
arch_interaction_itself=False,
|
||||
sigmoid_bot=-1,
|
||||
sigmoid_top=-1,
|
||||
weighted_pooling=None,
|
||||
):
|
||||
super(DLRM_Net, self).__init__()
|
||||
|
||||
if (
|
||||
(m_spa is not None)
|
||||
and (ln_emb is not None)
|
||||
and (ln_bot is not None)
|
||||
and (ln_top is not None)
|
||||
and (arch_interaction_op is not None)
|
||||
):
|
||||
# save arguments
|
||||
self.output_d = 0
|
||||
self.arch_interaction_op = arch_interaction_op
|
||||
self.arch_interaction_itself = arch_interaction_itself
|
||||
if weighted_pooling is not None and weighted_pooling != "fixed":
|
||||
self.weighted_pooling = "learned"
|
||||
else:
|
||||
self.weighted_pooling = weighted_pooling
|
||||
|
||||
# create operators
|
||||
self.emb_l, w_list = self.create_emb(
|
||||
m_spa, ln_emb, weighted_pooling
|
||||
)
|
||||
if self.weighted_pooling == "learned":
|
||||
self.v_W_l = nn.ParameterList()
|
||||
for w in w_list:
|
||||
self.v_W_l.append(nn.Parameter(w))
|
||||
else:
|
||||
self.v_W_l = w_list
|
||||
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
|
||||
self.top_l = self.create_mlp(ln_top, sigmoid_top)
|
||||
|
||||
def apply_mlp(self, x, layers):
|
||||
return layers(x)
|
||||
|
||||
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
|
||||
# WARNING: notice that we are processing the batch at once. We implicitly
|
||||
# assume that the data is laid out such that:
|
||||
# 1. each embedding is indexed with a group of sparse indices,
|
||||
# corresponding to a single lookup
|
||||
# 2. for each embedding the lookups are further organized into a batch
|
||||
# 3. for a list of embedding tables there is a list of batched lookups
|
||||
# TORCH-MLIR
|
||||
# We are passing all the embeddings as arguments for easy parsing.
|
||||
|
||||
ly = []
|
||||
for k, sparse_index_group_batch in enumerate(lS_i):
|
||||
sparse_offset_group_batch = lS_o[k]
|
||||
|
||||
# embedding lookup
|
||||
# We are using EmbeddingBag, which implicitly uses sum operator.
|
||||
# The embeddings are represented as tall matrices, with sum
|
||||
# happening vertically across 0 axis, resulting in a row vector
|
||||
# E = emb_l[k]
|
||||
|
||||
if v_W_l[k] is not None:
|
||||
per_sample_weights = v_W_l[k].gather(
|
||||
0, sparse_index_group_batch
|
||||
)
|
||||
else:
|
||||
per_sample_weights = None
|
||||
|
||||
E = emb_l[k]
|
||||
V = E(
|
||||
sparse_index_group_batch,
|
||||
sparse_offset_group_batch,
|
||||
per_sample_weights=per_sample_weights,
|
||||
)
|
||||
|
||||
ly.append(V)
|
||||
|
||||
return ly
|
||||
|
||||
def interact_features(self, x, ly):
|
||||
if self.arch_interaction_op == "dot":
|
||||
# concatenate dense and sparse features
|
||||
(batch_size, d) = x.shape
|
||||
T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
|
||||
# perform a dot product
|
||||
Z = torch.bmm(T, torch.transpose(T, 1, 2))
|
||||
# append dense feature with the interactions (into a row vector)
|
||||
# approach 1: all
|
||||
# Zflat = Z.view((batch_size, -1))
|
||||
# approach 2: unique
|
||||
_, ni, nj = Z.shape
|
||||
# approach 1: tril_indices
|
||||
# offset = 0 if self.arch_interaction_itself else -1
|
||||
# li, lj = torch.tril_indices(ni, nj, offset=offset)
|
||||
# approach 2: custom
|
||||
offset = 1 if self.arch_interaction_itself else 0
|
||||
li = torch.tensor(
|
||||
[i for i in range(ni) for j in range(i + offset)]
|
||||
)
|
||||
lj = torch.tensor(
|
||||
[j for i in range(nj) for j in range(i + offset)]
|
||||
)
|
||||
Zflat = Z[:, li, lj]
|
||||
# concatenate dense features and interactions
|
||||
R = torch.cat([x] + [Zflat], dim=1)
|
||||
elif self.arch_interaction_op == "cat":
|
||||
# concatenation features (into a row vector)
|
||||
R = torch.cat([x] + ly, dim=1)
|
||||
else:
|
||||
sys.exit(
|
||||
"ERROR: --arch-interaction-op="
|
||||
+ self.arch_interaction_op
|
||||
+ " is not supported"
|
||||
)
|
||||
|
||||
return R
|
||||
|
||||
def forward(self, dense_x, lS_o, *lS_i):
|
||||
return self.sequential_forward(dense_x, lS_o, lS_i)
|
||||
|
||||
def sequential_forward(self, dense_x, lS_o, lS_i):
|
||||
# process dense features (using bottom mlp), resulting in a row vector
|
||||
x = self.apply_mlp(dense_x, self.bot_l)
|
||||
# debug prints
|
||||
# print("intermediate")
|
||||
# print(x.detach().cpu().numpy())
|
||||
|
||||
# process sparse features(using embeddings), resulting in a list of row vectors
|
||||
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
|
||||
# for y in ly:
|
||||
# print(y.detach().cpu().numpy())
|
||||
|
||||
# interact features (dense and sparse)
|
||||
z = self.interact_features(x, ly)
|
||||
# print(z.detach().cpu().numpy())
|
||||
|
||||
# obtain probability of a click (using top mlp)
|
||||
p = self.apply_mlp(z, self.top_l)
|
||||
|
||||
# # clamp output if needed
|
||||
# if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
|
||||
# z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
|
||||
# else:
|
||||
# z = p
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def dash_separated_ints(value):
|
||||
vals = value.split("-")
|
||||
for val in vals:
|
||||
try:
|
||||
int(val)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"%s is not a valid dash separated list of ints" % value
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# model related parameters
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train Deep Learning Recommendation Model (DLRM)"
|
||||
)
|
||||
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
|
||||
parser.add_argument(
|
||||
"--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
|
||||
)
|
||||
# j will be replaced with the table number
|
||||
parser.add_argument(
|
||||
"--arch-mlp-bot", type=dash_separated_ints, default="4-3-2"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-mlp-top", type=dash_separated_ints, default="8-2-1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-interaction-itself", action="store_true", default=False
|
||||
)
|
||||
parser.add_argument("--weighted-pooling", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
|
||||
ln_top = np.fromstring(args.arch_mlp_top, dtype=int, sep="-")
|
||||
m_den = ln_bot[0]
|
||||
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
|
||||
m_spa = args.arch_sparse_feature_size
|
||||
ln_emb = np.asarray(ln_emb)
|
||||
num_fea = ln_emb.size + 1 # num sparse + num dense features
|
||||
|
||||
|
||||
# Initialize the model.
|
||||
dlrm_model = DLRM_Net(
|
||||
m_spa=m_spa,
|
||||
ln_emb=ln_emb,
|
||||
ln_bot=ln_bot,
|
||||
ln_top=ln_top,
|
||||
arch_interaction_op=args.arch_interaction_op,
|
||||
)
|
||||
|
||||
|
||||
# Inputs to the model.
|
||||
dense_inp = torch.tensor([[0.6965, 0.2861, 0.2269, 0.5513]])
|
||||
vs0 = torch.tensor([[0], [0], [0]], dtype=torch.int64)
|
||||
vsi = torch.tensor([1, 2, 3]), torch.tensor([1]), torch.tensor([1])
|
||||
|
||||
input_dlrm = (dense_inp, vs0, *vsi)
|
||||
|
||||
golden_output = dlrm_model(dense_inp, vs0, *vsi)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
dlrm_model,
|
||||
input_dlrm,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(input_dlrm)
|
||||
np.testing.assert_allclose(
|
||||
golden_output.detach().numpy(), result, rtol=1e-02, atol=1e-03
|
||||
)
|
||||
|
||||
|
||||
# Verified via torch-mlir.
|
||||
# import torch_mlir
|
||||
# from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
||||
# module = torch_mlir.compile(
|
||||
# dlrm_model, inputs, use_tracing=True, output_type="linalg-on-tensors"
|
||||
# )
|
||||
# backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
# compiled = backend.compile(module)
|
||||
# jit_module = backend.load(compiled)
|
||||
|
||||
# dense_numpy = dense_inp.numpy()
|
||||
# vs0_numpy = vs0.numpy()
|
||||
# vsi_numpy = [inp.numpy() for inp in vsi]
|
||||
|
||||
# numpy_inp = (dense_numpy, vs0_numpy, *vsi_numpy)
|
||||
|
||||
# print(jit_module.forward(*numpy_inp))
|
||||
311
shark/examples/shark_inference/sparse_arch.py
Normal file
311
shark/examples/shark_inference/sparse_arch.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchrec.datasets.utils import Batch
|
||||
from torchrec.modules.crossnet import LowRankCrossNet
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
|
||||
from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
||||
from torchrec.modules.embedding_modules import EmbeddingBagCollection
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from torchrec.models.dlrm import (
|
||||
choose,
|
||||
DenseArch,
|
||||
DLRM,
|
||||
InteractionArch,
|
||||
SparseArch,
|
||||
OverArch,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
def calculate_offsets(tensor_list, prev_values, prev_offsets):
|
||||
offset_init = 0
|
||||
offset_list = []
|
||||
values_list = []
|
||||
|
||||
if prev_offsets != None:
|
||||
offset_init = prev_values.shape[-1]
|
||||
for tensor in tensor_list:
|
||||
offset_list.append(offset_init)
|
||||
offset_init += tensor.shape[0]
|
||||
|
||||
concatendated_tensor_list = torch.cat(tensor_list)
|
||||
|
||||
if prev_values != None:
|
||||
concatendated_tensor_list = torch.cat(
|
||||
[prev_values, concatendated_tensor_list]
|
||||
)
|
||||
|
||||
concatenated_offsets = torch.tensor(offset_list)
|
||||
|
||||
if prev_offsets != None:
|
||||
concatenated_offsets = torch.cat([prev_offsets, concatenated_offsets])
|
||||
|
||||
return concatendated_tensor_list, concatenated_offsets
|
||||
|
||||
|
||||
# Have to make combined_keys as dict as to which embedding bags they
|
||||
# point to. {f1: 0, f3: 0, f2: 1}
|
||||
# The result will be a triple containing values, indices and pointer tensor.
|
||||
def to_list(key_jagged, combined_keys):
|
||||
key_jagged_dict = key_jagged.to_dict()
|
||||
combined_list = []
|
||||
|
||||
for key in combined_keys:
|
||||
prev_values, prev_offsets = calculate_offsets(
|
||||
key_jagged_dict[key].to_dense(), None, None
|
||||
)
|
||||
print(prev_values)
|
||||
print(prev_offsets)
|
||||
combined_list.append(prev_values)
|
||||
combined_list.append(prev_offsets)
|
||||
combined_list.append(torch.tensor(combined_keys[key]))
|
||||
|
||||
return combined_list
|
||||
|
||||
|
||||
class SparseArchShark(nn.Module):
|
||||
def create_emb(self, embedding_dim, num_embeddings_list):
|
||||
embedding_list = nn.ModuleList()
|
||||
for i in range(0, num_embeddings_list.size):
|
||||
num_embeddings = num_embeddings_list[i]
|
||||
EE = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")
|
||||
W = np.random.uniform(
|
||||
low=-np.sqrt(1 / num_embeddings),
|
||||
high=np.sqrt(1 / num_embeddings),
|
||||
size=(num_embeddings, embedding_dim),
|
||||
).astype(np.float32)
|
||||
EE.weight.data = torch.tensor(W, requires_grad=True)
|
||||
embedding_list.append(EE)
|
||||
return embedding_list
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
total_features,
|
||||
num_embeddings_list,
|
||||
):
|
||||
super(SparseArchShark, self).__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_features = total_features
|
||||
self.embedding_list = self.create_emb(
|
||||
embedding_dim, num_embeddings_list
|
||||
)
|
||||
|
||||
def forward(self, *batched_inputs):
|
||||
concatenated_list = []
|
||||
input_enum, embedding_enum = 0, 0
|
||||
|
||||
for k in range(len(batched_inputs) // 3):
|
||||
values = batched_inputs[input_enum]
|
||||
input_enum += 1
|
||||
offsets = batched_inputs[input_enum]
|
||||
input_enum += 1
|
||||
embedding_pointer = int(batched_inputs[input_enum])
|
||||
input_enum += 1
|
||||
|
||||
E = self.embedding_list[embedding_pointer]
|
||||
V = E(values, offsets)
|
||||
concatenated_list.append(V)
|
||||
|
||||
return torch.cat(concatenated_list, dim=1).reshape(
|
||||
-1, self.num_features, self.embedding_dim
|
||||
)
|
||||
|
||||
|
||||
def test_sparse_arch() -> None:
|
||||
D = 3
|
||||
eb1_config = EmbeddingBagConfig(
|
||||
name="t1",
|
||||
embedding_dim=D,
|
||||
num_embeddings=10,
|
||||
feature_names=["f1", "f3"],
|
||||
)
|
||||
eb2_config = EmbeddingBagConfig(
|
||||
name="t2",
|
||||
embedding_dim=D,
|
||||
num_embeddings=10,
|
||||
feature_names=["f2"],
|
||||
)
|
||||
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
|
||||
w1 = ebc.embedding_bags["t1"].weight
|
||||
w2 = ebc.embedding_bags["t2"].weight
|
||||
|
||||
sparse_arch = SparseArch(ebc)
|
||||
|
||||
keys = ["f1", "f2", "f3", "f4", "f5"]
|
||||
offsets = torch.tensor([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 19])
|
||||
features = KeyedJaggedTensor.from_offsets_sync(
|
||||
keys=keys,
|
||||
values=torch.tensor(
|
||||
[1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]
|
||||
),
|
||||
offsets=offsets,
|
||||
)
|
||||
sparse_archi = SparseArchShark(D, 3, np.array([10, 10]))
|
||||
sparse_archi.embedding_list[0].weight = w1
|
||||
sparse_archi.embedding_list[1].weight = w2
|
||||
inputs = to_list(features, {"f1": 0, "f3": 0, "f2": 1})
|
||||
|
||||
test_results = sparse_archi(*inputs)
|
||||
sparse_features = sparse_arch(features)
|
||||
|
||||
torch.allclose(
|
||||
sparse_features,
|
||||
test_results,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
test_sparse_arch()
|
||||
|
||||
|
||||
class DLRMShark(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim,
|
||||
total_features,
|
||||
num_embeddings_list,
|
||||
dense_in_features: int,
|
||||
dense_arch_layer_sizes: List[int],
|
||||
over_arch_layer_sizes: List[int],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sparse_arch: SparseArchShark = SparseArchShark(
|
||||
embedding_dim, total_features, num_embeddings_list
|
||||
)
|
||||
num_sparse_features: int = total_features
|
||||
|
||||
self.dense_arch = DenseArch(
|
||||
in_features=dense_in_features,
|
||||
layer_sizes=dense_arch_layer_sizes,
|
||||
)
|
||||
|
||||
self.inter_arch = InteractionArch(
|
||||
num_sparse_features=num_sparse_features,
|
||||
)
|
||||
|
||||
over_in_features: int = (
|
||||
embedding_dim
|
||||
+ choose(num_sparse_features, 2)
|
||||
+ num_sparse_features
|
||||
)
|
||||
|
||||
self.over_arch = OverArch(
|
||||
in_features=over_in_features,
|
||||
layer_sizes=over_arch_layer_sizes,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, dense_features: torch.Tensor, *sparse_features
|
||||
) -> torch.Tensor:
|
||||
embedded_dense = self.dense_arch(dense_features)
|
||||
embedded_sparse = self.sparse_arch(*sparse_features)
|
||||
concatenated_dense = self.inter_arch(
|
||||
dense_features=embedded_dense, sparse_features=embedded_sparse
|
||||
)
|
||||
logits = self.over_arch(concatenated_dense)
|
||||
return logits
|
||||
|
||||
|
||||
def test_dlrm() -> None:
|
||||
B = 2
|
||||
D = 8
|
||||
dense_in_features = 100
|
||||
|
||||
eb1_config = EmbeddingBagConfig(
|
||||
name="t1",
|
||||
embedding_dim=D,
|
||||
num_embeddings=100,
|
||||
feature_names=["f1", "f3"],
|
||||
)
|
||||
eb2_config = EmbeddingBagConfig(
|
||||
name="t2",
|
||||
embedding_dim=D,
|
||||
num_embeddings=100,
|
||||
feature_names=["f2"],
|
||||
)
|
||||
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
|
||||
sparse_features = KeyedJaggedTensor.from_offsets_sync(
|
||||
keys=["f1", "f3", "f2"],
|
||||
values=torch.tensor([1, 2, 4, 5, 4, 3, 2, 9, 1, 2, 3]),
|
||||
offsets=torch.tensor([0, 2, 4, 6, 8, 10, 11]),
|
||||
)
|
||||
ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
|
||||
sparse_nn = DLRM(
|
||||
embedding_bag_collection=ebc,
|
||||
dense_in_features=dense_in_features,
|
||||
dense_arch_layer_sizes=[20, D],
|
||||
over_arch_layer_sizes=[5, 1],
|
||||
)
|
||||
sparse_nn_nod = DLRMShark(
|
||||
embedding_dim=8,
|
||||
total_features=3,
|
||||
num_embeddings_list=np.array([100, 100]),
|
||||
dense_in_features=dense_in_features,
|
||||
dense_arch_layer_sizes=[20, D],
|
||||
over_arch_layer_sizes=[5, 1],
|
||||
)
|
||||
|
||||
dense_features = torch.rand((B, dense_in_features))
|
||||
|
||||
x = to_list(sparse_features, {"f1": 0, "f3": 0, "f2": 1})
|
||||
|
||||
w1 = ebc.embedding_bags["t1"].weight
|
||||
w2 = ebc.embedding_bags["t2"].weight
|
||||
|
||||
sparse_nn_nod.sparse_arch.embedding_list[0].weight = w1
|
||||
sparse_nn_nod.sparse_arch.embedding_list[1].weight = w2
|
||||
|
||||
sparse_nn_nod.dense_arch.load_state_dict(sparse_nn.dense_arch.state_dict())
|
||||
sparse_nn_nod.inter_arch.load_state_dict(sparse_nn.inter_arch.state_dict())
|
||||
sparse_nn_nod.over_arch.load_state_dict(sparse_nn.over_arch.state_dict())
|
||||
|
||||
logits = sparse_nn(
|
||||
dense_features=dense_features,
|
||||
sparse_features=sparse_features,
|
||||
)
|
||||
logits_nod = sparse_nn_nod(dense_features, *x)
|
||||
|
||||
# print(logits)
|
||||
# print(logits_nod)
|
||||
|
||||
# Import the module and print.
|
||||
mlir_importer = SharkImporter(
|
||||
sparse_nn_nod,
|
||||
(dense_features, *x),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(dlrm_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=True
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
dlrm_mlir, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
|
||||
torch.allclose(
|
||||
logits,
|
||||
logits_nod,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
|
||||
test_dlrm()
|
||||
35
shark/examples/shark_inference/t5_tf.py
Normal file
35
shark/examples/shark_inference/t5_tf.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
from transformers import T5Tokenizer, TFT5Model
|
||||
import tensorflow as tf
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
# Create a set of inputs
|
||||
t5_inputs = [
|
||||
tf.TensorSpec(shape=[1, 10], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[1, 10], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class T5Module(tf.Module):
|
||||
def __init__(self):
|
||||
super(T5Module, self).__init__()
|
||||
self.m = TFT5Model.from_pretrained("t5-small")
|
||||
self.m.predict = lambda x, y: self.m(input_ids=x, decoder_input_ids=y)
|
||||
|
||||
@tf.function(input_signature=t5_inputs, jit_compile=True)
|
||||
def forward(self, input_ids, decoder_input_ids):
|
||||
return self.m.predict(input_ids, decoder_input_ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Prepping Data
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
text = "I love the distilled version of models."
|
||||
inputs = tokenizer(text, return_tensors="tf").input_ids
|
||||
|
||||
shark_module = SharkInference(T5Module(), (inputs, inputs))
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward((inputs, inputs)))
|
||||
43
shark/examples/shark_inference/torch_vision_models_script.py
Normal file
43
shark/examples/shark_inference/torch_vision_models_script.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
class VisionModule(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.train(False)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model.forward(input)
|
||||
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
## The vision models present here: https://pytorch.org/vision/stable/models.html
|
||||
vision_models_list = [
|
||||
models.resnet18(pretrained=True),
|
||||
models.alexnet(pretrained=True),
|
||||
models.vgg16(pretrained=True),
|
||||
models.squeezenet1_0(pretrained=True),
|
||||
models.densenet161(pretrained=True),
|
||||
models.inception_v3(pretrained=True),
|
||||
models.shufflenet_v2_x1_0(pretrained=True),
|
||||
models.mobilenet_v2(pretrained=True),
|
||||
models.mobilenet_v3_small(pretrained=True),
|
||||
models.resnext50_32x4d(pretrained=True),
|
||||
models.wide_resnet50_2(pretrained=True),
|
||||
models.mnasnet1_0(pretrained=True),
|
||||
models.efficientnet_b0(pretrained=True),
|
||||
models.regnet_y_400mf(pretrained=True),
|
||||
models.regnet_x_400mf(pretrained=True),
|
||||
]
|
||||
|
||||
for i, vision_model in enumerate(vision_models_list):
|
||||
shark_module = SharkInference(
|
||||
VisionModule(vision_model),
|
||||
(input,),
|
||||
)
|
||||
shark_module.compile()
|
||||
shark_module.forward((input,))
|
||||
39
shark/examples/shark_inference/unet_script.py
Normal file
39
shark/examples/shark_inference/unet_script.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
|
||||
|
||||
class UnetModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = torch.hub.load(
|
||||
"mateuszbuda/brain-segmentation-pytorch",
|
||||
"unet",
|
||||
in_channels=3,
|
||||
out_channels=1,
|
||||
init_features=32,
|
||||
pretrained=True,
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
|
||||
|
||||
input = torch.randn(1, 3, 224, 224)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
UnetModule(),
|
||||
(input,),
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
(vision_mlir, func_name), inputs, golden_out = mlir_importer.import_debug(
|
||||
tracing_required=False
|
||||
)
|
||||
|
||||
shark_module = SharkInference(vision_mlir, mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input,))
|
||||
np.testing.assert_allclose(golden_out, result, rtol=1e-02, atol=1e-03)
|
||||
21
shark/examples/shark_inference/upscaler/main.py
Normal file
21
shark/examples/shark_inference/upscaler/main.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from pipeline_shark_stable_diffusion_upscale import (
|
||||
SharkStableDiffusionUpscalePipeline,
|
||||
)
|
||||
import torch
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipeline = SharkStableDiffusionUpscalePipeline(model_id)
|
||||
|
||||
# let's download an image
|
||||
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
|
||||
response = requests.get(url)
|
||||
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
low_res_img = low_res_img.resize((128, 128))
|
||||
|
||||
prompt = "a white cat"
|
||||
|
||||
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
|
||||
upscaled_image.save("upsampled_cat.png")
|
||||
98
shark/examples/shark_inference/upscaler/model_wrappers.py
Normal file
98
shark/examples/shark_inference/upscaler/model_wrappers.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx
|
||||
import torch
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
|
||||
model_input = {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 128, 128),),
|
||||
"unet": (
|
||||
torch.randn(2, 7, 128, 128), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, 77, 1024), # embedding
|
||||
torch.randn(2).to(torch.int64), # noise_level
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_clip_mlir(model_name="clip_text", extra_args=[]):
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
)
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def forward(self, input):
|
||||
return self.text_encoder(input)[0]
|
||||
|
||||
clip_model = CLIPText()
|
||||
shark_clip = compile_through_fx(
|
||||
clip_model,
|
||||
model_input["clip"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_clip
|
||||
|
||||
|
||||
def get_vae_mlir(model_name="vae", extra_args=[]):
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.decode(input, return_dict=False)[0]
|
||||
return x
|
||||
|
||||
vae = VaeModel()
|
||||
shark_vae = compile_through_fx(
|
||||
vae,
|
||||
model_input["vae"],
|
||||
model_name=model_name,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_vae
|
||||
|
||||
|
||||
def get_unet_mlir(model_name="unet", extra_args=[]):
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
)
|
||||
self.in_channels = self.unet.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(self, latent, timestep, text_embedding, noise_level):
|
||||
unet_out = self.unet.forward(
|
||||
latent,
|
||||
timestep,
|
||||
text_embedding,
|
||||
noise_level,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
return unet_out
|
||||
|
||||
unet = UnetModel()
|
||||
f16_input_mask = (True, True, True, False)
|
||||
shark_unet = compile_through_fx(
|
||||
unet,
|
||||
model_input["unet"],
|
||||
model_name=model_name,
|
||||
is_f16=True,
|
||||
f16_input_mask=f16_input_mask,
|
||||
extra_args=extra_args,
|
||||
)
|
||||
return shark_unet
|
||||
48
shark/examples/shark_inference/upscaler/opt_params.py
Normal file
48
shark/examples/shark_inference/upscaler/opt_params.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import sys
|
||||
from model_wrappers import (
|
||||
get_vae_mlir,
|
||||
get_unet_mlir,
|
||||
get_clip_mlir,
|
||||
)
|
||||
from upscaler_args import args
|
||||
from utils import get_shark_model
|
||||
|
||||
BATCH_SIZE = len(args.prompts)
|
||||
if BATCH_SIZE != 1:
|
||||
sys.exit("Only batch size 1 is supported.")
|
||||
|
||||
|
||||
unet_flag = [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
|
||||
vae_flag = [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
|
||||
clip_flag = [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
|
||||
bucket = "gs://shark_tank/stable_diffusion/"
|
||||
|
||||
|
||||
def get_unet():
|
||||
model_name = "upscaler_unet"
|
||||
if args.import_mlir:
|
||||
return get_unet_mlir(model_name, unet_flag)
|
||||
return get_shark_model(bucket, model_name, unet_flag)
|
||||
|
||||
|
||||
def get_vae():
|
||||
model_name = "upscaler_vae"
|
||||
if args.import_mlir:
|
||||
return get_vae_mlir(model_name, vae_flag)
|
||||
return get_shark_model(bucket, model_name, vae_flag)
|
||||
|
||||
|
||||
def get_clip():
|
||||
model_name = "upscaler_clip"
|
||||
if args.import_mlir:
|
||||
return get_clip_mlir(model_name, clip_flag)
|
||||
return get_shark_model(bucket, model_name, clip_flag)
|
||||
@@ -0,0 +1,489 @@
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from PIL import Image
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers import logging
|
||||
from diffusers.pipeline_utils import ImagePipelineOutput
|
||||
from opt_params import get_unet, get_vae, get_clip
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
w, h = image[0].size
|
||||
w, h = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
|
||||
image = [np.array(i.resize((w, h)))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
return image
|
||||
|
||||
|
||||
def shark_run_wrapper(model, *args):
|
||||
np_inputs = tuple([x.detach().numpy() for x in args])
|
||||
outputs = model("forward", np_inputs)
|
||||
return torch.from_numpy(outputs)
|
||||
|
||||
|
||||
class SharkStableDiffusionUpscalePipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_id,
|
||||
):
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
model_id, subfolder="tokenizer"
|
||||
)
|
||||
self.low_res_scheduler = DDPMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
self.scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
self.vae = get_vae()
|
||||
self.unet = get_unet()
|
||||
self.text_encoder = get_clip()
|
||||
self.max_noise_level = (350,)
|
||||
self._execution_device = "cpu"
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
Args:
|
||||
prompt (`str` or `list(int)`):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
||||
-1
|
||||
] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
# if (
|
||||
# hasattr(self.text_encoder.config, "use_attention_mask")
|
||||
# and self.text_encoder.config.use_attention_mask
|
||||
# ):
|
||||
# attention_mask = text_inputs.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
text_embeddings = shark_run_wrapper(
|
||||
self.text_encoder, text_input_ids.to(device)
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# if (
|
||||
# hasattr(self.text_encoder.config, "use_attention_mask")
|
||||
# and self.text_encoder.config.use_attention_mask
|
||||
# ):
|
||||
# attention_mask = uncond_input.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
uncond_embeddings = shark_run_wrapper(
|
||||
self.text_encoder,
|
||||
uncond_input.input_ids.to(device),
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(
|
||||
1, num_images_per_prompt, 1
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = shark_run_wrapper(self.vae, latents)
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def check_inputs(self, prompt, image, noise_level, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor):
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = len(prompt)
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
image_batch_size = image.shape[0]
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(
|
||||
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
|
||||
" Please make sure that passed `prompt` matches the batch size of `image`."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
if images.shape[-1] == 1:
|
||||
# special case for grayscale (single channel) images
|
||||
pil_images = [
|
||||
Image.fromarray(image.squeeze(), mode="L") for image in images
|
||||
]
|
||||
else:
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
if latents is None:
|
||||
if device == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(
|
||||
shape, generator=generator, device="cpu", dtype=dtype
|
||||
).to(device)
|
||||
else:
|
||||
latents = torch.randn(
|
||||
shape, generator=generator, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(
|
||||
f"Unexpected latents shape, got {latents.shape}, expected {shape}"
|
||||
)
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[
|
||||
torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]
|
||||
],
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[
|
||||
Union[torch.Generator, List[torch.Generator]]
|
||||
] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[
|
||||
Callable[[int, int, torch.FloatTensor], None]
|
||||
] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, image, noise_level, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
)
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
image = image.to(dtype=text_embeddings.dtype, device=device)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor(
|
||||
[noise_level], dtype=torch.long, device=device
|
||||
)
|
||||
if device == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(
|
||||
image.shape,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
dtype=text_embeddings.dtype,
|
||||
).to(device)
|
||||
else:
|
||||
noise = torch.randn(
|
||||
image.shape,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=text_embeddings.dtype,
|
||||
)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
|
||||
noise_level = torch.cat([noise_level] * image.shape[0])
|
||||
|
||||
# 6. Prepare latent variables
|
||||
height, width = image.shape[2:]
|
||||
# num_channels_latents = self.vae.config.latent_channels
|
||||
num_channels_latents = 4
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Check that sizes of image and latents match
|
||||
num_channels_image = image.shape[1]
|
||||
# if (
|
||||
# num_channels_latents + num_channels_image
|
||||
# != self.unet.config.in_channels
|
||||
# ):
|
||||
# raise ValueError(
|
||||
# f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
# f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
# f" `num_channels_image`: {num_channels_image} "
|
||||
# f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
# " `pipeline.unet` or your `image` input."
|
||||
# )
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = (
|
||||
len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
)
|
||||
for i, t in tqdm(enumerate(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents
|
||||
)
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
latent_model_input = torch.cat([latent_model_input, image], dim=1)
|
||||
|
||||
timestep = torch.tensor([t]).to(torch.float32)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = shark_run_wrapper(
|
||||
self.unet,
|
||||
latent_model_input.half(),
|
||||
timestep,
|
||||
text_embeddings.half(),
|
||||
noise_level,
|
||||
)
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# # call the callback, if provided
|
||||
# if i == len(timesteps) - 1 or (
|
||||
# (i + 1) > num_warmup_steps
|
||||
# and (i + 1) % self.scheduler.order == 0
|
||||
# ):
|
||||
# progress_bar.update()
|
||||
# if callback is not None and i % callback_steps == 0:
|
||||
# callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
# self.vae.to(dtype=torch.float32)
|
||||
image = self.decode_latents(latents.float())
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
98
shark/examples/shark_inference/upscaler/upscaler_args.py
Normal file
98
shark/examples/shark_inference/upscaler/upscaler_args.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Stable Diffusion Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--prompts",
|
||||
nargs="+",
|
||||
default=["cyberpunk forest by Salvador Dali"],
|
||||
help="text of which images to be generated.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--negative-prompts",
|
||||
nargs="+",
|
||||
default=[""],
|
||||
help="text you don't want to see in the generated image.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the no. of steps to do the sampling.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed to use.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="the value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### Model Config and Usage Params
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--device", type=str, default="vulkan", help="device to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--precision", type=str, default="fp16", help="precision to run the model."
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--import_mlir",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="imports the model from torch module to shark_module otherwise downloads the model from shark_tank.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--load_vmfb",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="attempts to load the model from a precompiled flatbuffer and compiles + saves it if not found.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--save_vmfb",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="saves the compiled flatbuffer to the local directory",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree-vulkan-target-triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
|
||||
args = p.parse_args()
|
||||
230
shark/examples/shark_inference/upscaler/utils.py
Normal file
230
shark/examples/shark_inference/upscaler/utils.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import os
|
||||
import torch
|
||||
from shark.shark_inference import SharkInference
|
||||
from upscaler_args import args
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
|
||||
def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
if args.load_vmfb or args.save_vmfb:
|
||||
device = (
|
||||
args.device
|
||||
if "://" not in args.device
|
||||
else "-".join(args.device.split("://"))
|
||||
)
|
||||
extended_name = "{}_{}".format(model_name, device)
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_name + ".vmfb")
|
||||
if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb:
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=extra_args)
|
||||
else:
|
||||
if args.save_vmfb:
|
||||
print("Saving to {}".format(vmfb_path))
|
||||
else:
|
||||
print(
|
||||
"No vmfb found. Compiling and saving to {}".format(
|
||||
vmfb_path
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
shark_module.compile(extra_args)
|
||||
return shark_module
|
||||
|
||||
|
||||
# Downloads the model from shark_tank and returns the shark_module.
|
||||
def get_shark_model(tank_url, model_name, extra_args=[]):
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
|
||||
# Set local shark_tank cache directory.
|
||||
# shark_args.local_tank_cache = args.local_tank_cache
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
model_name,
|
||||
tank_url=tank_url,
|
||||
frontend="torch",
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
# Converts the torch-module into a shark_module.
|
||||
def compile_through_fx(
|
||||
model, inputs, model_name, is_f16=False, f16_input_mask=None, extra_args=[]
|
||||
):
|
||||
mlir_module, func_name = import_with_fx(
|
||||
model, inputs, is_f16, f16_input_mask
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
|
||||
return _compile_module(shark_module, model_name, extra_args)
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
f"--vulkan_debug_utils=true",
|
||||
]
|
||||
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
|
||||
|
||||
|
||||
def get_all_devices(driver_name):
|
||||
"""
|
||||
Inputs: driver_name
|
||||
Returns a list of all the available devices for a given driver sorted by
|
||||
the iree path names of the device as in --list_devices option in iree.
|
||||
"""
|
||||
from iree.runtime import get_driver
|
||||
|
||||
driver = get_driver(driver_name)
|
||||
device_list_src = driver.query_available_devices()
|
||||
device_list_src.sort(key=lambda d: d["path"])
|
||||
return device_list_src
|
||||
|
||||
|
||||
def get_device_mapping(driver, key_combination=3):
|
||||
"""This method ensures consistent device ordering when choosing
|
||||
specific devices for execution
|
||||
Args:
|
||||
driver (str): execution driver (vulkan, cuda, rocm, etc)
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Returns:
|
||||
dict: map to possible device names user can input mapped to desired combination of name/path.
|
||||
"""
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
driver = iree_device_map(driver)
|
||||
device_list = get_all_devices(driver)
|
||||
device_map = dict()
|
||||
|
||||
def get_output_value(dev_dict):
|
||||
if key_combination == 1:
|
||||
return f"{driver}://{dev_dict['path']}"
|
||||
if key_combination == 2:
|
||||
return dev_dict["name"]
|
||||
if key_combination == 3:
|
||||
return (dev_dict["name"], f"{driver}://{dev_dict['path']}")
|
||||
|
||||
# mapping driver name to default device (driver://0)
|
||||
device_map[f"{driver}"] = get_output_value(device_list[0])
|
||||
for i, device in enumerate(device_list):
|
||||
# mapping with index
|
||||
device_map[f"{driver}://{i}"] = get_output_value(device)
|
||||
# mapping with full path
|
||||
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
|
||||
return device_map
|
||||
|
||||
|
||||
def map_device_to_name_path(device, key_combination=3):
|
||||
"""Gives the appropriate device data (supported name/path) for user selected execution device
|
||||
Args:
|
||||
device (str): user
|
||||
key_combination (int, optional): choice for mapping value for device name.
|
||||
1 : path
|
||||
2 : name
|
||||
3 : (name, path)
|
||||
Defaults to 3.
|
||||
Raises:
|
||||
ValueError:
|
||||
Returns:
|
||||
str / tuple: returns the mapping str or tuple of mapping str for the device depending on key_combination value
|
||||
"""
|
||||
driver = device.split("://")[0]
|
||||
device_map = get_device_mapping(driver, key_combination)
|
||||
try:
|
||||
device_mapping = device_map[device]
|
||||
except KeyError:
|
||||
raise ValueError(f"Device '{device}' is not a valid device.")
|
||||
return device_mapping
|
||||
|
||||
|
||||
def set_init_device_flags():
|
||||
if "vulkan" in args.device:
|
||||
# set runtime flags for vulkan.
|
||||
set_iree_runtime_flags()
|
||||
|
||||
# set triple flag to avoid multiple calls to get_vulkan_triple_flag
|
||||
device_name, args.device = map_device_to_name_path(args.device)
|
||||
if not args.iree_vulkan_target_triple:
|
||||
triple = get_vulkan_target_triple(device_name)
|
||||
if triple is not None:
|
||||
args.iree_vulkan_target_triple = triple
|
||||
print(
|
||||
f"Found device {device_name}. Using target triple {args.iree_vulkan_target_triple}."
|
||||
)
|
||||
elif "cuda" in args.device:
|
||||
args.device = "cuda"
|
||||
elif "cpu" in args.device:
|
||||
args.device = "cpu"
|
||||
|
||||
# set max_length based on availability.
|
||||
if args.variant in ["anythingv3", "analogdiffusion", "dreamlike"]:
|
||||
args.max_length = 77
|
||||
elif args.variant == "openjourney":
|
||||
args.max_length = 64
|
||||
|
||||
# use tuned models only in the case of stablediffusion/fp16 and rdna3 cards.
|
||||
if (
|
||||
args.variant in ["openjourney", "dreamlike"]
|
||||
or args.precision != "fp16"
|
||||
or "vulkan" not in args.device
|
||||
or "rdna3" not in args.iree_vulkan_target_triple
|
||||
):
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
elif args.use_base_vae and args.variant != "stablediffusion":
|
||||
args.use_tuned = False
|
||||
print("Tuned models are currently not supported for this setting.")
|
||||
|
||||
if args.use_tuned:
|
||||
print("Using tuned models for stablediffusion/fp16 and rdna3 card.")
|
||||
|
||||
|
||||
# Utility to get list of devices available.
|
||||
def get_available_devices():
|
||||
def get_devices_by_name(driver_name):
|
||||
from shark.iree_utils._common import iree_device_map
|
||||
|
||||
device_list = []
|
||||
try:
|
||||
driver_name = iree_device_map(driver_name)
|
||||
device_list_dict = get_all_devices(driver_name)
|
||||
print(f"{driver_name} devices are available.")
|
||||
except:
|
||||
print(f"{driver_name} devices are not available.")
|
||||
else:
|
||||
for i, device in enumerate(device_list_dict):
|
||||
device_list.append(f"{driver_name}://{i} => {device['name']}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
available_devices = []
|
||||
vulkan_devices = get_devices_by_name("vulkan")
|
||||
available_devices.extend(vulkan_devices)
|
||||
cuda_devices = get_devices_by_name("cuda")
|
||||
available_devices.extend(cuda_devices)
|
||||
available_devices.append("cpu")
|
||||
return available_devices
|
||||
15
shark/examples/shark_inference/v_diffusion.py
Normal file
15
shark/examples/shark_inference/v_diffusion.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
|
||||
|
||||
mlir_model, func_name, inputs, golden_out = download_model(
|
||||
"v_diffusion", frontend="torch"
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device="vulkan", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
print("The golden result is:", golden_out)
|
||||
48
shark/examples/shark_training/bert_training.py
Normal file
48
shark/examples/shark_training/bert_training.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from torch.nn.utils import stateless
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
|
||||
|
||||
class MiniLMSequenceClassification(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"microsoft/MiniLM-L12-H384-uncased", # The pretrained model.
|
||||
num_labels=2, # The number of output labels--2 for binary classification.
|
||||
output_attentions=False, # Whether the model returns attentions weights.
|
||||
output_hidden_states=False, # Whether the model returns all hidden-states.
|
||||
torchscript=True,
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
return self.model.forward(tokens)[0]
|
||||
|
||||
|
||||
mod = MiniLMSequenceClassification()
|
||||
|
||||
|
||||
def get_sorted_params(named_params):
|
||||
return [i[1] for i in sorted(named_params.items())]
|
||||
|
||||
|
||||
print(dict(mod.named_buffers()))
|
||||
|
||||
inp = (torch.randint(2, (1, 128)),)
|
||||
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
stateless.functional_call(
|
||||
mod, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
# optim.load_state_dict(optim_state)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
|
||||
shark_module = SharkTrainer(mod, inp)
|
||||
shark_module.compile(forward)
|
||||
shark_module.train(num_iters=2)
|
||||
print("training done")
|
||||
60
shark/examples/shark_training/bert_training_load_tf.py
Normal file
60
shark/examples/shark_training/bert_training_load_tf.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
import tensorflow as tf
|
||||
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
from shark.parser import parser
|
||||
from urllib import request
|
||||
|
||||
parser.add_argument(
|
||||
"--download_mlir_path",
|
||||
type=str,
|
||||
default="bert_tf_training.mlir",
|
||||
help="Specifies path to target mlir file that will be loaded.",
|
||||
)
|
||||
load_args, unknown = parser.parse_known_args()
|
||||
|
||||
tf.random.set_seed(0)
|
||||
vocab_size = 100
|
||||
NUM_CLASSES = 5
|
||||
SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
|
||||
# Download BERT model from tank and train.
|
||||
if __name__ == "__main__":
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
file_link = "https://storage.googleapis.com/shark_tank/users/stanley/bert_tf_training.mlir"
|
||||
response = request.urlretrieve(file_link, load_args.download_mlir_path)
|
||||
sample_input_tensors = [
|
||||
tf.convert_to_tensor(val, dtype=tf.int32)
|
||||
for val in predict_sample_input
|
||||
]
|
||||
num_iter = 10
|
||||
if not os.path.isfile(load_args.download_mlir_path):
|
||||
raise ValueError(
|
||||
f"Tried looking for target mlir in {load_args.download_mlir_path}, but cannot be found."
|
||||
)
|
||||
with open(load_args.download_mlir_path, "rb") as input_file:
|
||||
bert_mlir = input_file.read()
|
||||
shark_module = SharkTrainer(
|
||||
bert_mlir,
|
||||
(
|
||||
sample_input_tensors,
|
||||
tf.convert_to_tensor(
|
||||
np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32
|
||||
),
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("mhlo")
|
||||
shark_module.compile()
|
||||
start = time.time()
|
||||
print(shark_module.train(num_iter))
|
||||
end = time.time()
|
||||
total_time = end - start
|
||||
print("time: " + str(total_time))
|
||||
print("time/iter: " + str(total_time / num_iter))
|
||||
98
shark/examples/shark_training/bert_training_tf.py
Normal file
98
shark/examples/shark_training/bert_training_tf.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from absl import app
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from official.nlp.modeling import layers
|
||||
from official.nlp.modeling import networks
|
||||
from official.nlp.modeling.models import bert_classifier
|
||||
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
|
||||
|
||||
tf.random.set_seed(0)
|
||||
vocab_size = 100
|
||||
NUM_CLASSES = 5
|
||||
SEQUENCE_LENGTH = 512
|
||||
BATCH_SIZE = 1
|
||||
# Create a set of 2-dimensional inputs
|
||||
bert_input = [
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
tf.TensorSpec(shape=[BATCH_SIZE, SEQUENCE_LENGTH], dtype=tf.int32),
|
||||
]
|
||||
|
||||
|
||||
class BertModule(tf.Module):
|
||||
def __init__(self):
|
||||
super(BertModule, self).__init__()
|
||||
dict_outputs = False
|
||||
test_network = networks.BertEncoder(
|
||||
vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs
|
||||
)
|
||||
|
||||
# Create a BERT trainer with the created network.
|
||||
bert_trainer_model = bert_classifier.BertClassifier(
|
||||
test_network, num_classes=NUM_CLASSES
|
||||
)
|
||||
bert_trainer_model.summary()
|
||||
|
||||
# Invoke the trainer model on the inputs. This causes the layer to be built.
|
||||
self.m = bert_trainer_model
|
||||
self.m.predict = lambda x: self.m.call(x, training=False)
|
||||
self.predict = tf.function(input_signature=[bert_input])(
|
||||
self.m.predict
|
||||
)
|
||||
self.m.learn = lambda x, y: self.m.call(x, training=False)
|
||||
self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
|
||||
self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
bert_input, # inputs
|
||||
tf.TensorSpec(shape=[BATCH_SIZE], dtype=tf.int32), # labels
|
||||
],
|
||||
jit_compile=True,
|
||||
)
|
||||
def forward(self, inputs, labels):
|
||||
with tf.GradientTape() as tape:
|
||||
# Capture the gradients from forward prop...
|
||||
probs = self.m(inputs, training=True)
|
||||
loss = self.loss(labels, probs)
|
||||
|
||||
# ...and use them to update the model's weights.
|
||||
variables = self.m.trainable_variables
|
||||
gradients = tape.gradient(loss, variables)
|
||||
self.optimizer.apply_gradients(zip(gradients, variables))
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict_sample_input = [
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
np.random.randint(5, size=(BATCH_SIZE, SEQUENCE_LENGTH)),
|
||||
]
|
||||
sample_input_tensors = [
|
||||
tf.convert_to_tensor(val, dtype=tf.int32)
|
||||
for val in predict_sample_input
|
||||
]
|
||||
num_iter = 10
|
||||
shark_module = SharkTrainer(
|
||||
BertModule(),
|
||||
(
|
||||
sample_input_tensors,
|
||||
tf.convert_to_tensor(
|
||||
np.random.randint(5, size=(BATCH_SIZE)), dtype=tf.int32
|
||||
),
|
||||
),
|
||||
)
|
||||
shark_module.set_frontend("tensorflow")
|
||||
shark_module.compile()
|
||||
start = time.time()
|
||||
print(shark_module.train(num_iter))
|
||||
end = time.time()
|
||||
total_time = end - start
|
||||
print("time: " + str(total_time))
|
||||
print("time/iter: " + str(total_time / num_iter))
|
||||
44
shark/examples/shark_training/neural_net_training.py
Normal file
44
shark/examples/shark_training/neural_net_training.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from torch.nn.utils import _stateless
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Foo, self).__init__()
|
||||
self.l1 = torch.nn.Linear(10, 16)
|
||||
self.relu = torch.nn.ReLU()
|
||||
self.l2 = torch.nn.Linear(16, 2)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.l1(x)
|
||||
out = self.relu(out)
|
||||
out = self.l2(out)
|
||||
return out
|
||||
|
||||
|
||||
mod = Foo()
|
||||
inp = (torch.randn(10, 10),)
|
||||
|
||||
|
||||
def get_sorted_params(named_params):
|
||||
return [i[1] for i in sorted(named_params.items())]
|
||||
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
_stateless.functional_call(
|
||||
mod, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
|
||||
# fx_graph = forward(dict(mod.named_parameters()), dict(mod.named_buffers()), inp)
|
||||
|
||||
shark_module = SharkTrainer(mod, inp)
|
||||
# Pass the training function in case of torch
|
||||
shark_module.compile(training_fn=forward)
|
||||
|
||||
shark_module.train(num_iters=10)
|
||||
@@ -0,0 +1,41 @@
|
||||
# Stable Diffusion Img2Img model
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Installation (Linux)</summary>
|
||||
|
||||
### Activate shark.venv Virtual Environment
|
||||
|
||||
```shell
|
||||
source shark.venv/bin/activate
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
### Install dependencies
|
||||
|
||||
# Run the setup.sh script
|
||||
|
||||
```shell
|
||||
./setup.sh
|
||||
```
|
||||
|
||||
### Run the Stable diffusion Img2Img model
|
||||
|
||||
To run the model with the default set of images and params, run:
|
||||
```shell
|
||||
python stable_diffusion_img2img.py
|
||||
```
|
||||
To run the model with your set of images, and parameters you need to specify the following params:
|
||||
1.) Input images directory with the arg `--input_dir` containing 3-5 images.
|
||||
2.) What to teach the model? Using the arg `--what_to_teach`, allowed values are `object` or `style`.
|
||||
3.) Placeholder token using the arg `--placeholder_token`, that represents your new concept. It should be passed with the opening and closing angle brackets. For ex: token is `cat-toy`, it should be passed as `<cat-toy>`.
|
||||
4.) Initializer token using the arg `--initializer_token`, which summarise what is your new concept.
|
||||
|
||||
For the result, you need to pass the text prompt with the arg: `--prompt`. The prompt string should contain a "*s" in it, which will be replaced by the placeholder token during the inference.
|
||||
|
||||
By default the result images will go into the `sd_result` dir. To specify your output dir use the arg: `--output_dir`.
|
||||
|
||||
The default value of max_training_steps is `3000`, which takes some hours to complete. You can pass the smaller value with the arg `--training_steps`. Specify the number of images to be sampled for the result with the `--num_inference_samples` arg.
|
||||
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
TD="$(cd $(dirname $0) && pwd)"
|
||||
if [ -z "$PYTHON" ]; then
|
||||
PYTHON="$(which python3)"
|
||||
fi
|
||||
|
||||
function die() {
|
||||
echo "Error executing command: $*"
|
||||
exit 1
|
||||
}
|
||||
|
||||
PYTHON_VERSION_X_Y=`${PYTHON} -c 'import sys; version=sys.version_info[:2]; print("{0}.{1}".format(*version))'`
|
||||
|
||||
echo "Python: $PYTHON"
|
||||
echo "Python version: $PYTHON_VERSION_X_Y"
|
||||
|
||||
mkdir input_images
|
||||
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg -P input_images/
|
||||
wget https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg -P input_images/
|
||||
|
||||
pip install diffusers["training"]==0.4.1 transformers ftfy opencv-python
|
||||
@@ -0,0 +1,600 @@
|
||||
# Textual-inversion fine-tuning for Stable Diffusion using diffusers
|
||||
# This script shows how to "teach" Stable Diffusion a new concept via
|
||||
# textual-inversion using 🤗 Hugging Face [🧨 Diffusers library](https://github.com/huggingface/diffusers).
|
||||
# By using just 3-5 images you can teach new concepts to Stable Diffusion
|
||||
# and personalize the model on your own images.
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
YOUR_TOKEN = "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"
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
p.add_argument(
|
||||
"--input_dir",
|
||||
type=str,
|
||||
default="input_images/",
|
||||
help="the directory contains the images used for fine tuning",
|
||||
)
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd_result",
|
||||
help="the directory contains the images used for fine tuning",
|
||||
)
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=3000,
|
||||
help="the maximum number of training steps",
|
||||
)
|
||||
p.add_argument("--seed", type=int, default=42, help="the random seed")
|
||||
p.add_argument(
|
||||
"--what_to_teach",
|
||||
type=str,
|
||||
choices=["object", "style"],
|
||||
default="object",
|
||||
help="what is it that you are teaching?",
|
||||
)
|
||||
p.add_argument(
|
||||
"--placeholder_token",
|
||||
type=str,
|
||||
default="<cat-toy>",
|
||||
help="It is the token you are going to use to represent your new concept",
|
||||
)
|
||||
p.add_argument(
|
||||
"--initializer_token",
|
||||
type=str,
|
||||
default="toy",
|
||||
help="It is a word that can summarise what is your new concept",
|
||||
)
|
||||
p.add_argument(
|
||||
"--inference_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the number of steps for inference",
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_inference_samples",
|
||||
type=int,
|
||||
default=4,
|
||||
help="the number of samples for inference",
|
||||
)
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a grafitti in a wall with a *s on it",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
args = p.parse_args()
|
||||
|
||||
if "*s" not in args.prompt:
|
||||
raise ValueError(
|
||||
f'The prompt should have a "*s" which will be replaced by a placeholder token.'
|
||||
)
|
||||
|
||||
prompt1, prompt2 = args.prompt.split("*s")
|
||||
args.prompt = prompt1 + args.placeholder_token + prompt2
|
||||
|
||||
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
# Load input images.
|
||||
images = []
|
||||
for filename in os.listdir(args.input_dir):
|
||||
img = cv2.imread(os.path.join(args.input_dir, filename))
|
||||
if img is not None:
|
||||
images.append(img)
|
||||
|
||||
# Setup the prompt templates for training
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
# Setup the dataset
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
|
||||
self.templates = (
|
||||
imagenet_style_templates_small
|
||||
if learnable_property == "style"
|
||||
else imagenet_templates_small
|
||||
)
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
# Setting up the model
|
||||
# Load the tokenizer and add the placeholder token as a additional special token.
|
||||
# Please read and if you agree accept the LICENSE
|
||||
# [here](https://huggingface.co/CompVis/stable-diffusion-v1-4) if you see an error
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Get token ids for our placeholder and initializer token.
|
||||
# This code block will complain if initializer string is not a single token
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
use_auth_token=YOUR_TOKEN,
|
||||
)
|
||||
|
||||
# We have added the `placeholder_token` in the `tokenizer` so we resize the token embeddings here,
|
||||
# this will a new embedding vector in the token embeddings for our `placeholder_token`
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# In Textual-Inversion we only train the newly added embedding vector,
|
||||
# so lets freeze rest of the model parameters here.
|
||||
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
# Freeze vae and unet
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
|
||||
# Creating our training data
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=args.input_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=512,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=100,
|
||||
learnable_property=args.what_to_teach, # Option selected above between object and style
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
|
||||
def create_dataloader(train_batch_size=1):
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True
|
||||
)
|
||||
|
||||
|
||||
# Create noise_scheduler for training.
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
tensor_format="pt",
|
||||
)
|
||||
|
||||
# Define hyperparameters for our training
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": args.training_steps,
|
||||
"train_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"seed": args.seed,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
|
||||
|
||||
def training_function(text_encoder, vae, unet):
|
||||
logger = get_logger(__name__)
|
||||
|
||||
train_batch_size = hyperparameters["train_batch_size"]
|
||||
gradient_accumulation_steps = hyperparameters[
|
||||
"gradient_accumulation_steps"
|
||||
]
|
||||
learning_rate = hyperparameters["learning_rate"]
|
||||
max_train_steps = hyperparameters["max_train_steps"]
|
||||
output_dir = hyperparameters["output_dir"]
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
train_dataloader = create_dataloader(train_batch_size)
|
||||
|
||||
if hyperparameters["scale_lr"]:
|
||||
learning_rate = (
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
* accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
)
|
||||
|
||||
text_encoder, optimizer, train_dataloader = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
|
||||
# Keep vae and unet in eval model as we don't train these
|
||||
vae.eval()
|
||||
unet.eval()
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Train!
|
||||
total_batch_size = (
|
||||
train_batch_size
|
||||
* accelerator.num_processes
|
||||
* gradient_accumulation_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
range(max_train_steps), disable=not accelerator.is_local_main_process
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = (
|
||||
vae.encode(batch["pixel_values"])
|
||||
.latent_dist.sample()
|
||||
.detach()
|
||||
)
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(latents.shape).to(latents.device)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(
|
||||
latents, noise, timesteps
|
||||
)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(
|
||||
noisy_latents, timesteps, encoder_hidden_states
|
||||
).sample
|
||||
|
||||
loss = (
|
||||
F.mse_loss(noise_pred, noise, reduction="none")
|
||||
.mean([1, 2, 3])
|
||||
.mean()
|
||||
)
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Zero out the gradients for all token embeddings except the newly added
|
||||
# embeddings for the concept, as we only want to optimize the concept embeddings
|
||||
if accelerator.num_processes > 1:
|
||||
grads = (
|
||||
text_encoder.module.get_input_embeddings().weight.grad
|
||||
)
|
||||
else:
|
||||
grads = text_encoder.get_input_embeddings().weight.grad
|
||||
# Get the index for tokens that we want to zero the grads for
|
||||
index_grads_to_zero = (
|
||||
torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
)
|
||||
grads.data[index_grads_to_zero, :] = grads.data[
|
||||
index_grads_to_zero, :
|
||||
].fill_(0)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
skip_prk_steps=True,
|
||||
),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker"
|
||||
),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained(
|
||||
"openai/clip-vit-base-patch32"
|
||||
),
|
||||
)
|
||||
pipeline.save_pretrained(output_dir)
|
||||
# Also save the newly trained embeddings
|
||||
learned_embeds = (
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight[placeholder_token_id]
|
||||
)
|
||||
learned_embeds_dict = {
|
||||
args.placeholder_token: learned_embeds.detach().cpu()
|
||||
}
|
||||
torch.save(
|
||||
learned_embeds_dict, os.path.join(output_dir, "learned_embeds.bin")
|
||||
)
|
||||
|
||||
|
||||
import accelerate
|
||||
|
||||
accelerate.notebook_launcher(
|
||||
training_function, args=(text_encoder, vae, unet), num_processes=1
|
||||
)
|
||||
|
||||
# Set up the pipeline
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
hyperparameters["output_dir"],
|
||||
# torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
all_images = []
|
||||
for _ in range(args.num_inference_samples):
|
||||
images = pipe(
|
||||
[args.prompt],
|
||||
num_inference_steps=args.inference_steps,
|
||||
guidance_scale=7.5,
|
||||
).images
|
||||
all_images.extend(images)
|
||||
|
||||
# output_path = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
|
||||
if not os.path.isdir(args.output_dir):
|
||||
os.mkdir(args.output_dir)
|
||||
|
||||
[
|
||||
image.save(f"{args.output_dir}/{i}.jpeg")
|
||||
for i, image in enumerate(all_images)
|
||||
]
|
||||
43
shark/examples/shark_training/stable_diffusion/README.md
Normal file
43
shark/examples/shark_training/stable_diffusion/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Stable Diffusion Fine Tuning
|
||||
|
||||
## Installation (Linux)
|
||||
|
||||
### Activate shark.venv Virtual Environment
|
||||
|
||||
```shell
|
||||
source shark.venv/bin/activate
|
||||
|
||||
# Some older pip installs may not be able to handle the recent PyTorch deps
|
||||
python -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
## Install dependencies
|
||||
|
||||
### Run the following installation commands:
|
||||
```
|
||||
pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
pip install accelerate transformers ftfy
|
||||
```
|
||||
|
||||
### Build torch-mlir with the following branch:
|
||||
|
||||
Please cherry-pick this branch of torch-mlir: https://github.com/vivekkhandelwal1/torch-mlir/tree/sd-ops
|
||||
and build it locally. You can find the instructions for using locally build Torch-MLIR,
|
||||
here: https://github.com/nod-ai/SHARK#how-to-use-your-locally-built-iree--torch-mlir-with-shark
|
||||
|
||||
## Run the Stable diffusion fine tuning
|
||||
|
||||
To run the model with the default set of images and params, run:
|
||||
```shell
|
||||
python stable_diffusion_fine_tuning.py
|
||||
```
|
||||
By default the training is run through the PyTorch path. If you want to train the model using the Torchdynamo path of Torch-MLIR, you need to specify `--use_torchdynamo=True`.
|
||||
|
||||
The default number of training steps are `2000`, which would take many hours to complete based on your system config. You can pass the smaller value with the arg `--training_steps`. You can specify the number of images to be sampled for the result with the `--num_inference_samples` arg. For the number of inference steps you can use `--inference_steps` flag.
|
||||
|
||||
For example, you can run the training for a limited set of steps via the dynamo path by using the following command:
|
||||
```
|
||||
python stable_diffusion_fine_tuning.py --training_steps=1 --inference_steps=1 --num_inference_samples=1 --train_batch_size=1 --use_torchdynamo=True
|
||||
```
|
||||
|
||||
You can also specify the device to be used via the flag `--device`. The default value is `cpu`, for GPU execution you can specify `--device="cuda"`.
|
||||
@@ -0,0 +1,914 @@
|
||||
# Install the required libs
|
||||
# pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
# pip install accelerate transformers ftfy
|
||||
|
||||
# Import required libraries
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
import logging
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
import torch._dynamo as dynamo
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
torch._dynamo.config.verbose = True
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import (
|
||||
CLIPFeatureExtractor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
|
||||
# Enter your HuggingFace Token
|
||||
# Note: You can comment this prompt and just set your token instead of passing it through cli for every execution.
|
||||
hf_token = input("Please enter your huggingface token here: ")
|
||||
YOUR_TOKEN = hf_token
|
||||
|
||||
|
||||
def image_grid(imgs, rows, cols):
|
||||
assert len(imgs) == rows * cols
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
grid_w, grid_h = grid.size
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
return grid
|
||||
|
||||
|
||||
# `pretrained_model_name_or_path` which Stable Diffusion checkpoint you want to use
|
||||
# Options: 1.) "stabilityai/stable-diffusion-2"
|
||||
# 2.) "stabilityai/stable-diffusion-2-base"
|
||||
# 3.) "CompVis/stable-diffusion-v1-4"
|
||||
# 4.) "runwayml/stable-diffusion-v1-5"
|
||||
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2"
|
||||
|
||||
# Add here the URLs to the images of the concept you are adding. 3-5 should be fine
|
||||
urls = [
|
||||
"https://huggingface.co/datasets/valhalla/images/resolve/main/2.jpeg",
|
||||
"https://huggingface.co/datasets/valhalla/images/resolve/main/3.jpeg",
|
||||
"https://huggingface.co/datasets/valhalla/images/resolve/main/5.jpeg",
|
||||
"https://huggingface.co/datasets/valhalla/images/resolve/main/6.jpeg",
|
||||
## You can add additional images here
|
||||
]
|
||||
|
||||
# Downloading Images
|
||||
import requests
|
||||
import glob
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def download_image(url):
|
||||
try:
|
||||
response = requests.get(url)
|
||||
except:
|
||||
return None
|
||||
return Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
|
||||
images = list(filter(None, [download_image(url) for url in urls]))
|
||||
save_path = "./my_concept"
|
||||
if not os.path.exists(save_path):
|
||||
os.mkdir(save_path)
|
||||
[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]
|
||||
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__,
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
p.add_argument(
|
||||
"--input_dir",
|
||||
type=str,
|
||||
default="my_concept/",
|
||||
help="the directory contains the images used for fine tuning",
|
||||
)
|
||||
p.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd_result",
|
||||
help="the directory contains the images used for fine tuning",
|
||||
)
|
||||
p.add_argument(
|
||||
"--training_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="the maximum number of training steps",
|
||||
)
|
||||
p.add_argument(
|
||||
"--train_batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The batch size for training",
|
||||
)
|
||||
p.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=250,
|
||||
help="the number of steps after which to save the learned concept",
|
||||
)
|
||||
p.add_argument("--seed", type=int, default=42, help="the random seed")
|
||||
p.add_argument(
|
||||
"--what_to_teach",
|
||||
type=str,
|
||||
choices=["object", "style"],
|
||||
default="object",
|
||||
help="what is it that you are teaching?",
|
||||
)
|
||||
p.add_argument(
|
||||
"--placeholder_token",
|
||||
type=str,
|
||||
default="<cat-toy>",
|
||||
help="It is the token you are going to use to represent your new concept",
|
||||
)
|
||||
p.add_argument(
|
||||
"--initializer_token",
|
||||
type=str,
|
||||
default="toy",
|
||||
help="It is a word that can summarise what is your new concept",
|
||||
)
|
||||
p.add_argument(
|
||||
"--inference_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="the number of steps for inference",
|
||||
)
|
||||
p.add_argument(
|
||||
"--num_inference_samples",
|
||||
type=int,
|
||||
default=4,
|
||||
help="the number of samples for inference",
|
||||
)
|
||||
p.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="a grafitti in a wall with a *s on it",
|
||||
help="the text prompt to use",
|
||||
)
|
||||
p.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="The device to use",
|
||||
)
|
||||
p.add_argument(
|
||||
"--use_torchdynamo",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="This flag is used to determine whether the training has to be done through the torchdynamo path or not.",
|
||||
)
|
||||
args = p.parse_args()
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
if "*s" not in args.prompt:
|
||||
raise ValueError(
|
||||
f'The prompt should have a "*s" which will be replaced by a placeholder token.'
|
||||
)
|
||||
|
||||
prompt1, prompt2 = args.prompt.split("*s")
|
||||
args.prompt = prompt1 + args.placeholder_token + prompt2
|
||||
|
||||
# `images_path` is a path to directory containing the training images.
|
||||
images_path = args.input_dir
|
||||
while not os.path.exists(str(images_path)):
|
||||
print(
|
||||
"The images_path specified does not exist, use the colab file explorer to copy the path :"
|
||||
)
|
||||
images_path = input("")
|
||||
save_path = images_path
|
||||
|
||||
# Setup and check the images you have just added
|
||||
images = []
|
||||
for file_path in os.listdir(save_path):
|
||||
try:
|
||||
image_path = os.path.join(save_path, file_path)
|
||||
images.append(Image.open(image_path).resize((512, 512)))
|
||||
except:
|
||||
print(
|
||||
f"{image_path} is not a valid image, please make sure to remove this file from the directory otherwise the training could fail."
|
||||
)
|
||||
image_grid(images, 1, len(images))
|
||||
|
||||
########### Create Dataset ##########
|
||||
|
||||
# Setup the prompt templates for training
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
# Setup the dataset
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
}[interpolation]
|
||||
|
||||
self.templates = (
|
||||
imagenet_style_templates_small
|
||||
if learnable_property == "style"
|
||||
else imagenet_templates_small
|
||||
)
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[
|
||||
(h - crop) // 2 : (h + crop) // 2,
|
||||
(w - crop) // 2 : (w + crop) // 2,
|
||||
]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize(
|
||||
(self.size, self.size), resample=self.interpolation
|
||||
)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
########## Setting up the model ##########
|
||||
|
||||
# Load the tokenizer and add the placeholder token as a additional special token.
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Get token ids for our placeholder and initializer token.
|
||||
# This code block will complain if initializer string is not a single token
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load the Stable Diffusion model
|
||||
# Load models and create wrapper for stable diffusion
|
||||
# pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
|
||||
# del pipeline
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="text_encoder"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="vae"
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
pretrained_model_name_or_path, subfolder="unet"
|
||||
)
|
||||
|
||||
# We have added the placeholder_token in the tokenizer so we resize the token embeddings here
|
||||
# this will a new embedding vector in the token embeddings for our placeholder_token
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# In Textual-Inversion we only train the newly added embedding vector
|
||||
# so lets freeze rest of the model parameters here
|
||||
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
# Freeze vae and unet
|
||||
freeze_params(vae.parameters())
|
||||
freeze_params(unet.parameters())
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
text_encoder.text_model.encoder.parameters(),
|
||||
text_encoder.text_model.final_layer_norm.parameters(),
|
||||
text_encoder.text_model.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
|
||||
|
||||
# Move vae and unet to device
|
||||
# For the dynamo path default compilation device is `cpu`, since torch-mlir
|
||||
# supports only that. Therefore, convert to device only for PyTorch path.
|
||||
if not args.use_torchdynamo:
|
||||
vae.to(args.device)
|
||||
unet.to(args.device)
|
||||
|
||||
# Keep vae in eval mode as we don't train it
|
||||
vae.eval()
|
||||
# Keep unet in train mode to enable gradient checkpointing
|
||||
unet.train()
|
||||
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
|
||||
def forward(self, input):
|
||||
x = self.vae.encode(input, return_dict=False)[0]
|
||||
return x
|
||||
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
|
||||
def forward(self, x, y, z):
|
||||
return self.unet.forward(x, y, z, return_dict=False)[0]
|
||||
|
||||
|
||||
shark_vae = VaeModel()
|
||||
shark_unet = UnetModel()
|
||||
|
||||
####### Creating our training data ########
|
||||
|
||||
# Let's create the Dataset and Dataloader
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=save_path,
|
||||
tokenizer=tokenizer,
|
||||
size=vae.sample_size,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=100,
|
||||
learnable_property=args.what_to_teach, # Option selected above between object and style
|
||||
center_crop=False,
|
||||
set="train",
|
||||
)
|
||||
|
||||
|
||||
def create_dataloader(train_batch_size=1):
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=train_batch_size, shuffle=True
|
||||
)
|
||||
|
||||
|
||||
# Create noise_scheduler for training
|
||||
noise_scheduler = DDPMScheduler.from_config(
|
||||
pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
|
||||
######## Training ###########
|
||||
|
||||
# Define hyperparameters for our training. If you are not happy with your results,
|
||||
# you can tune the `learning_rate` and the `max_train_steps`
|
||||
|
||||
# Setting up all training args
|
||||
hyperparameters = {
|
||||
"learning_rate": 5e-04,
|
||||
"scale_lr": True,
|
||||
"max_train_steps": args.training_steps,
|
||||
"save_steps": args.save_steps,
|
||||
"train_batch_size": args.train_batch_size,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"gradient_checkpointing": True,
|
||||
"mixed_precision": "fp16",
|
||||
"seed": 42,
|
||||
"output_dir": "sd-concept-output",
|
||||
}
|
||||
# creating output directory
|
||||
cwd = os.getcwd()
|
||||
out_dir = os.path.join(cwd, hyperparameters["output_dir"])
|
||||
while not os.path.exists(str(out_dir)):
|
||||
try:
|
||||
os.mkdir(out_dir)
|
||||
except OSError as error:
|
||||
print("Output directory not created")
|
||||
|
||||
###### Torch-MLIR Compilation ######
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def transform_fx(fx_g):
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
if node.target in [
|
||||
torch.ops.aten.empty,
|
||||
]:
|
||||
# aten.empty should be filled with zeros.
|
||||
if node.target in [torch.ops.aten.empty]:
|
||||
with fx_g.graph.inserting_after(node):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten.zero_,
|
||||
args=(node,),
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
@make_simple_dynamo_backend
|
||||
def refbackend_torchdynamo_backend(
|
||||
fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
|
||||
):
|
||||
# handling usage of empty tensor without initializing
|
||||
transform_fx(fx_graph)
|
||||
fx_graph.recompile()
|
||||
if _returns_nothing(fx_graph):
|
||||
return fx_graph
|
||||
removed_none_indexes = _remove_nones(fx_graph)
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
|
||||
|
||||
mlir_module = torch_mlir.compile(
|
||||
fx_graph, example_inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
|
||||
bytecode_stream = BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def compiled_callable(*inputs):
|
||||
inputs = [x.numpy() for x in inputs]
|
||||
result = shark_module("forward", inputs)
|
||||
if was_unwrapped:
|
||||
result = [
|
||||
result,
|
||||
]
|
||||
if not isinstance(result, list):
|
||||
result = torch.from_numpy(result)
|
||||
else:
|
||||
result = tuple(torch.from_numpy(x) for x in result)
|
||||
result = list(result)
|
||||
for removed_index in removed_none_indexes:
|
||||
result.insert(removed_index, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
|
||||
return compiled_callable
|
||||
|
||||
|
||||
def predictions(torch_func, jit_func, batchA, batchB):
|
||||
res = jit_func(batchA.numpy(), batchB.numpy())
|
||||
if res is not None:
|
||||
prediction = res
|
||||
else:
|
||||
prediction = None
|
||||
return prediction
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# def save_progress(text_encoder, placeholder_token_id, accelerator, save_path):
|
||||
def save_progress(text_encoder, placeholder_token_id, save_path):
|
||||
logger.info("Saving embeddings")
|
||||
learned_embeds = (
|
||||
# accelerator.unwrap_model(text_encoder)
|
||||
text_encoder.get_input_embeddings().weight[placeholder_token_id]
|
||||
)
|
||||
learned_embeds_dict = {
|
||||
args.placeholder_token: learned_embeds.detach().cpu()
|
||||
}
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
|
||||
train_batch_size = hyperparameters["train_batch_size"]
|
||||
gradient_accumulation_steps = hyperparameters["gradient_accumulation_steps"]
|
||||
learning_rate = hyperparameters["learning_rate"]
|
||||
if hyperparameters["scale_lr"]:
|
||||
learning_rate = (
|
||||
learning_rate
|
||||
* gradient_accumulation_steps
|
||||
* train_batch_size
|
||||
# * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=learning_rate,
|
||||
)
|
||||
|
||||
|
||||
# Training function
|
||||
def train_func(batch_pixel_values, batch_input_ids):
|
||||
# Convert images to latent space
|
||||
latents = shark_vae(batch_pixel_values).sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch_input_ids)[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = shark_unet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
)
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
||||
)
|
||||
|
||||
loss = (
|
||||
F.mse_loss(noise_pred, target, reduction="none").mean([1, 2, 3]).mean()
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
# Zero out the gradients for all token embeddings except the newly added
|
||||
# embeddings for the concept, as we only want to optimize the concept embeddings
|
||||
grads = text_encoder.get_input_embeddings().weight.grad
|
||||
# Get the index for tokens that we want to zero the grads for
|
||||
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||
grads.data[index_grads_to_zero, :] = grads.data[
|
||||
index_grads_to_zero, :
|
||||
].fill_(0)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def training_function():
|
||||
max_train_steps = hyperparameters["max_train_steps"]
|
||||
output_dir = hyperparameters["output_dir"]
|
||||
gradient_checkpointing = hyperparameters["gradient_checkpointing"]
|
||||
|
||||
train_dataloader = create_dataloader(train_batch_size)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Train!
|
||||
total_batch_size = (
|
||||
train_batch_size
|
||||
* gradient_accumulation_steps
|
||||
# train_batch_size * accelerator.num_processes * gradient_accumulation_steps
|
||||
)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Instantaneous batch size per device = {train_batch_size}")
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f" Gradient Accumulation steps = {gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
# range(max_train_steps), disable=not accelerator.is_local_main_process
|
||||
range(max_train_steps)
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
params_ = [i for i in text_encoder.get_input_embeddings().parameters()]
|
||||
if args.use_torchdynamo:
|
||||
print("******** TRAINING STARTED - TORCHYDNAMO PATH ********")
|
||||
else:
|
||||
print("******** TRAINING STARTED - PYTORCH PATH ********")
|
||||
print("Initial weights:")
|
||||
print(params_, params_[0].shape)
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if args.use_torchdynamo:
|
||||
dynamo_callable = dynamo.optimize(
|
||||
refbackend_torchdynamo_backend
|
||||
)(train_func)
|
||||
lam_func = lambda x, y: dynamo_callable(
|
||||
torch.from_numpy(x), torch.from_numpy(y)
|
||||
)
|
||||
loss = predictions(
|
||||
train_func,
|
||||
lam_func,
|
||||
batch["pixel_values"],
|
||||
batch["input_ids"],
|
||||
# params[0].detach(),
|
||||
)
|
||||
else:
|
||||
loss = train_func(batch["pixel_values"], batch["input_ids"])
|
||||
print(loss)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % hyperparameters["save_steps"] == 0:
|
||||
save_path = os.path.join(
|
||||
output_dir,
|
||||
f"learned_embeds-step-{global_step}.bin",
|
||||
)
|
||||
save_progress(
|
||||
text_encoder,
|
||||
placeholder_token_id,
|
||||
save_path,
|
||||
)
|
||||
|
||||
logs = {"loss": loss.detach().item()}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= max_train_steps:
|
||||
break
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
params__ = [i for i in text_encoder.get_input_embeddings().parameters()]
|
||||
print("******** TRAINING PROCESS FINISHED ********")
|
||||
print("Updated weights:")
|
||||
print(params__, params__[0].shape)
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
# text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
)
|
||||
pipeline.save_pretrained(output_dir)
|
||||
# Also save the newly trained embeddings
|
||||
save_path = os.path.join(output_dir, f"learned_embeds.bin")
|
||||
save_progress(text_encoder, placeholder_token_id, save_path)
|
||||
|
||||
|
||||
training_function()
|
||||
|
||||
for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
|
||||
if param.grad is not None:
|
||||
del param.grad # free some memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Set up the pipeline
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
hyperparameters["output_dir"],
|
||||
scheduler=DPMSolverMultistepScheduler.from_pretrained(
|
||||
hyperparameters["output_dir"], subfolder="scheduler"
|
||||
),
|
||||
)
|
||||
if not args.use_torchdynamo:
|
||||
pipe.to(args.device)
|
||||
|
||||
# Run the Stable Diffusion pipeline
|
||||
# Don't forget to use the placeholder token in your prompt
|
||||
|
||||
all_images = []
|
||||
for _ in range(args.num_inference_samples):
|
||||
images = pipe(
|
||||
[args.prompt],
|
||||
num_inference_steps=args.inference_steps,
|
||||
guidance_scale=7.5,
|
||||
).images
|
||||
all_images.extend(images)
|
||||
|
||||
output_path = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
|
||||
if not os.path.isdir(args.output_dir):
|
||||
os.mkdir(args.output_dir)
|
||||
|
||||
[
|
||||
image.save(f"{args.output_dir}/{i}.jpeg")
|
||||
for i, image in enumerate(all_images)
|
||||
]
|
||||
86
shark/iree_eager_backend.py
Normal file
86
shark/iree_eager_backend.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
import iree
|
||||
import iree.runtime as ireert
|
||||
import numpy as np
|
||||
import torch
|
||||
from iree.runtime import DeviceArray
|
||||
from torch_mlir._mlir_libs._mlir.ir import Module
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
)
|
||||
from torch_mlir.eager_mode.torch_mlir_eager_backend import (
|
||||
TorchMLIREagerBackend,
|
||||
TensorMetaData,
|
||||
)
|
||||
from torch_mlir_e2e_test.eager_backends.refbackend import (
|
||||
NUMPY_TO_TORCH_DTYPE_DICT,
|
||||
)
|
||||
|
||||
from shark.iree_utils.compile_utils import (
|
||||
get_iree_compiled_module,
|
||||
IREE_DEVICE_MAP,
|
||||
)
|
||||
|
||||
|
||||
class EagerModeIREELinalgOnTensorsBackend(TorchMLIREagerBackend):
|
||||
"""Main entry-point for the iree backend for torch-mlir eager mode.
|
||||
|
||||
EagerModeIREELinalgOnTensorsBackend uses iree.DeviceArray representations of tensors and
|
||||
thus all of the wrapping and unwrapping and munging here is done to between torch.Tensor and iree.DeviceArray,
|
||||
with np.ndarray as an intermediary.
|
||||
"""
|
||||
|
||||
def __init__(self, device: str):
|
||||
self.torch_device_str = device
|
||||
self.config = ireert.Config(IREE_DEVICE_MAP[device])
|
||||
self.raw_device_str = device
|
||||
|
||||
def get_torch_metadata(
|
||||
self, tensor: DeviceArray, kwargs: Dict[str, Any]
|
||||
) -> TensorMetaData:
|
||||
return TensorMetaData(
|
||||
size=tensor.shape,
|
||||
dtype=NUMPY_TO_TORCH_DTYPE_DICT[tensor.dtype.type],
|
||||
device=torch.device(self.torch_device_str),
|
||||
requires_grad=tensor.dtype.type
|
||||
in {np.float, np.float32, np.float64}
|
||||
and kwargs.get("requires_grad", False),
|
||||
)
|
||||
|
||||
def compile(self, imported_module: Module):
|
||||
run_pipeline_with_repro_report(
|
||||
imported_module,
|
||||
"torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
|
||||
"EagerMode",
|
||||
)
|
||||
callable, _ = get_iree_compiled_module(
|
||||
imported_module, self.raw_device_str
|
||||
)
|
||||
return callable
|
||||
|
||||
def copy_into(self, dst, src):
|
||||
"""Copy output back to appropriate arg that it should alias."""
|
||||
np.copyto(dst, src)
|
||||
|
||||
def transfer_from_device_to_torch(self, e):
|
||||
return torch.from_numpy(e.to_host())
|
||||
|
||||
def transfer_from_torch_to_device(
|
||||
self, tensor: torch.Tensor
|
||||
) -> DeviceArray:
|
||||
return iree.runtime.asdevicearray(self.config.device, tensor.numpy())
|
||||
0
shark/iree_utils/__init__.py
Normal file
0
shark/iree_utils/__init__.py
Normal file
164
shark/iree_utils/_common.py
Normal file
164
shark/iree_utils/_common.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
## Common utilities to be shared by iree utilities.
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
|
||||
def run_cmd(cmd, debug=False, raise_err=False):
|
||||
"""
|
||||
Inputs:
|
||||
cmd : cli command string.
|
||||
debug : if True, prints debug info
|
||||
raise_err : if True, raise exception to caller
|
||||
"""
|
||||
if debug:
|
||||
print("IREE run command: \n\n")
|
||||
print(cmd)
|
||||
print("\n\n")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
check=True,
|
||||
)
|
||||
stdout = result.stdout.decode()
|
||||
stderr = result.stderr.decode()
|
||||
return stdout, stderr
|
||||
except subprocess.CalledProcessError as e:
|
||||
if raise_err:
|
||||
raise Exception from e
|
||||
else:
|
||||
print(e.output)
|
||||
sys.exit(f"Exiting program due to error running {cmd}")
|
||||
|
||||
|
||||
def iree_device_map(device):
|
||||
uri_parts = device.split("://", 2)
|
||||
iree_driver = (
|
||||
_IREE_DEVICE_MAP[uri_parts[0]]
|
||||
if uri_parts[0] in _IREE_DEVICE_MAP
|
||||
else uri_parts[0]
|
||||
)
|
||||
if len(uri_parts) == 1:
|
||||
return iree_driver
|
||||
elif "rocm" in uri_parts:
|
||||
return "rocm"
|
||||
else:
|
||||
return f"{iree_driver}://{uri_parts[1]}"
|
||||
|
||||
|
||||
def get_supported_device_list():
|
||||
return list(_IREE_DEVICE_MAP.keys())
|
||||
|
||||
|
||||
_IREE_DEVICE_MAP = {
|
||||
"cpu": "local-task",
|
||||
"cpu-task": "local-task",
|
||||
"cpu-sync": "local-sync",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "hip",
|
||||
"intel-gpu": "level_zero",
|
||||
}
|
||||
|
||||
|
||||
def iree_target_map(device):
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
return _IREE_TARGET_MAP[device] if device in _IREE_TARGET_MAP else device
|
||||
|
||||
|
||||
_IREE_TARGET_MAP = {
|
||||
"cpu": "llvm-cpu",
|
||||
"cpu-task": "llvm-cpu",
|
||||
"cpu-sync": "llvm-cpu",
|
||||
"cuda": "cuda",
|
||||
"vulkan": "vulkan-spirv",
|
||||
"metal": "metal",
|
||||
"rocm": "rocm",
|
||||
"hip": "rocm",
|
||||
"intel-gpu": "opencl-spirv",
|
||||
}
|
||||
|
||||
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
@functools.cache
|
||||
def check_device_drivers(device):
|
||||
"""
|
||||
Checks necessary drivers present for gpu and vulkan devices
|
||||
False => drivers present!
|
||||
"""
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
|
||||
from iree.runtime import get_driver
|
||||
|
||||
device_mapped = iree_device_map(device)
|
||||
|
||||
try:
|
||||
_ = get_driver(device_mapped)
|
||||
except ValueError as ve:
|
||||
print(
|
||||
f"[ERR] device `{device}` not registered with IREE. "
|
||||
"Ensure IREE is configured for use with this device.\n"
|
||||
f"Full Error: \n {repr(ve)}"
|
||||
)
|
||||
return True
|
||||
except RuntimeError as re:
|
||||
print(f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}")
|
||||
return True
|
||||
|
||||
# Unknown device. We assume drivers are installed.
|
||||
return False
|
||||
|
||||
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
|
||||
device_driver_err_map = {
|
||||
"cuda": {
|
||||
"debug": "Try `nvidia-smi` on system to check.",
|
||||
"solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
|
||||
},
|
||||
"vulkan": {
|
||||
"debug": "Try `vulkaninfo` on system to check.",
|
||||
"solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.",
|
||||
},
|
||||
"metal": {
|
||||
"debug": "Check if Bare metal is supported and enabled on your system.",
|
||||
"solution": ".",
|
||||
},
|
||||
"rocm": {
|
||||
"debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.",
|
||||
"solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
|
||||
},
|
||||
}
|
||||
|
||||
if device in device_driver_err_map:
|
||||
err_msg = (
|
||||
f"Required drivers for {device} not found. {device_driver_err_map[device]['debug']} "
|
||||
f"Please install the required drivers{device_driver_err_map[device]['solution']} "
|
||||
f"For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
|
||||
f" and/or file a bug at https://github.com/nod-ai/SHARK/issues"
|
||||
)
|
||||
return err_msg
|
||||
else:
|
||||
return f"{device} is not supported."
|
||||
154
shark/iree_utils/benchmark_utils.py
Normal file
154
shark/iree_utils/benchmark_utils.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from shark.iree_utils._common import run_cmd, iree_device_map
|
||||
from shark.iree_utils.cpu_utils import get_cpu_count
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import platform
|
||||
|
||||
UNIT_TO_SECOND_MAP = {"us": 1e-6, "ms": 0.001, "s": 1}
|
||||
|
||||
|
||||
def tensor_to_type_str(input_tensors: tuple, mlir_dialect: str):
|
||||
"""
|
||||
Input: A tuple of input tensors i.e tuple(torch.tensor)
|
||||
Output: list of string that represent mlir types (i.e 1x24xf64)
|
||||
# TODO: Support more than floats, and ints
|
||||
"""
|
||||
list_of_type = []
|
||||
for input_tensor in input_tensors:
|
||||
type_string = "x".join([str(dim) for dim in input_tensor.shape])
|
||||
if mlir_dialect in ["linalg", "tosa"]:
|
||||
dtype_string = str(input_tensor.dtype).replace("torch.", "")
|
||||
elif mlir_dialect in ["mhlo", "tflite"]:
|
||||
dtype = input_tensor.dtype
|
||||
try:
|
||||
dtype_string = re.findall("'[^\"]*'", str(dtype))[0].replace(
|
||||
"'", ""
|
||||
)
|
||||
except IndexError:
|
||||
dtype_string = str(dtype)
|
||||
regex_split = re.compile("([a-zA-Z]+)([0-9]+)")
|
||||
match = regex_split.match(dtype_string)
|
||||
mlir_type_string = str(match.group(1)[0]) + str(match.group(2))
|
||||
type_string += f"x{mlir_type_string}"
|
||||
list_of_type.append(type_string)
|
||||
return list_of_type
|
||||
|
||||
|
||||
def build_benchmark_args(
|
||||
input_file: str,
|
||||
device: str,
|
||||
input_tensors: tuple,
|
||||
mlir_dialect: str,
|
||||
training=False,
|
||||
):
|
||||
"""
|
||||
Inputs: input_file leading to vmfb, input_tensor to function, target device,
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function named can be passed as one of the args.
|
||||
fn_name = "forward"
|
||||
if training == True:
|
||||
# TODO: Replace name of train with actual train fn name.
|
||||
fn_name = "train"
|
||||
benchmark_cl.append(f"--function={fn_name}")
|
||||
benchmark_cl.append(f"--device={iree_device_map(device)}")
|
||||
mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect)
|
||||
for mlir_input in mlir_input_types:
|
||||
benchmark_cl.append(f"--input={mlir_input}")
|
||||
if device == "cpu":
|
||||
num_cpus = get_cpu_count()
|
||||
if num_cpus is not None:
|
||||
benchmark_cl.append(f"--task_topology_max_group_count={num_cpus}")
|
||||
# if time_extractor:
|
||||
# benchmark_cl.append(time_extractor)
|
||||
benchmark_cl.append(f"--print_statistics=true")
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
def build_benchmark_args_non_tensor_input(
|
||||
input_file: str,
|
||||
device: str,
|
||||
inputs: tuple,
|
||||
mlir_dialect: str,
|
||||
function_name: str,
|
||||
):
|
||||
"""
|
||||
Inputs: input_file leading to vmfb, input_tensor to function, target device,
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function named can be passed as one of the args.
|
||||
if function_name:
|
||||
benchmark_cl.append(f"--function={function_name}")
|
||||
benchmark_cl.append(f"--device={iree_device_map(device)}")
|
||||
for input in inputs:
|
||||
benchmark_cl.append(f"--input={input}")
|
||||
if platform.system() != "Windows":
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl.append(time_extractor)
|
||||
return benchmark_cl
|
||||
|
||||
|
||||
def run_benchmark_module(benchmark_cl):
|
||||
"""
|
||||
Run benchmark command, extract result and return iteration/seconds, host
|
||||
peak memory, and device peak memory.
|
||||
|
||||
# TODO: Add an example of the benchmark command.
|
||||
Input: benchmark command.
|
||||
"""
|
||||
benchmark_path = benchmark_cl[0]
|
||||
assert os.path.exists(
|
||||
benchmark_path
|
||||
), "Cannot find iree_benchmark_module, Please contact SHARK maintainer on discord."
|
||||
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
|
||||
try:
|
||||
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_stdout)
|
||||
time_ms = float(match.group(1))
|
||||
unit = match.group(3)
|
||||
except AttributeError:
|
||||
regex_split = re.compile("(\d+[.]*\d*)([a-zA-Z]+)")
|
||||
match = regex_split.search(bench_stdout)
|
||||
time_ms = float(match.group(1))
|
||||
unit = match.group(2)
|
||||
iter_per_second = 1.0 / (time_ms * 0.001)
|
||||
|
||||
# Extract peak memory.
|
||||
host_regex = re.compile(r".*HOST_LOCAL:\s*([0-9]+)B peak")
|
||||
host_peak_b = int(host_regex.search(bench_stderr).group(1))
|
||||
device_regex = re.compile(r".*DEVICE_LOCAL:\s*([0-9]+)B peak")
|
||||
device_peak_b = int(device_regex.search(bench_stderr).group(1))
|
||||
return iter_per_second, host_peak_b, device_peak_b
|
||||
704
shark/iree_utils/compile_utils.py
Normal file
704
shark/iree_utils/compile_utils.py
Normal file
@@ -0,0 +1,704 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import functools
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.parser import shark_args
|
||||
|
||||
from .trace import DetailLogger
|
||||
from ._common import iree_device_map, iree_target_map
|
||||
from .cpu_utils import get_iree_cpu_rt_args
|
||||
from .benchmark_utils import *
|
||||
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
print("Configuring for device:" + device)
|
||||
device, device_num = clean_device_info(device)
|
||||
|
||||
if "cpu" in device:
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
|
||||
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
|
||||
|
||||
return (
|
||||
get_iree_cpu_args()
|
||||
+ u_kernel_flag
|
||||
+ stack_size_flag
|
||||
)
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
return get_iree_gpu_args()
|
||||
if device == "vulkan":
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device == "metal":
|
||||
from shark.iree_utils.metal_utils import get_iree_metal_args
|
||||
|
||||
return get_iree_metal_args(extra_args=extra_args)
|
||||
if device == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
|
||||
if device == "hip":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True)
|
||||
return []
|
||||
|
||||
def get_iree_target_triple(device):
|
||||
args = get_iree_device_args(device)
|
||||
for flag in args:
|
||||
if "triple" in flag:
|
||||
triple = flag.split("=")[-1]
|
||||
return triple
|
||||
return ""
|
||||
|
||||
|
||||
def clean_device_info(raw_device):
|
||||
# return appropriate device and device_id for consumption by Studio pipeline
|
||||
# Multiple devices only supported for vulkan and rocm (as of now).
|
||||
# default device must be selected for all others
|
||||
|
||||
device_id = None
|
||||
device = (
|
||||
raw_device
|
||||
if "=>" not in raw_device
|
||||
else raw_device.split("=>")[1].strip()
|
||||
)
|
||||
if "://" in device:
|
||||
device, device_id = device.split("://")
|
||||
if len(device_id) <= 2:
|
||||
device_id = int(device_id)
|
||||
|
||||
if device not in ["hip", "rocm", "vulkan"]:
|
||||
device_id = None
|
||||
if device in ["hip", "rocm", "vulkan"] and device_id == None:
|
||||
device_id = 0
|
||||
return device, device_id
|
||||
|
||||
|
||||
# Get the iree-compiler arguments given frontend.
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
|
||||
return ["--iree-llvmcpu-target-cpu-features=host"]
|
||||
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
|
||||
return [
|
||||
"--iree-llvmcpu-target-cpu-features=host",
|
||||
"--iree-input-demote-i64-to-i32",
|
||||
]
|
||||
else:
|
||||
# Frontend not found.
|
||||
return []
|
||||
|
||||
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args(debug=False):
|
||||
common_args = [
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
"--mlir-elide-elementsattrs-if-larger=10",
|
||||
]
|
||||
if debug == True:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=false",
|
||||
"--verify=true",
|
||||
]
|
||||
)
|
||||
else:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=true",
|
||||
"--verify=false",
|
||||
]
|
||||
)
|
||||
return common_args
|
||||
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
# shark_args are passed down from pytests to control which models compile with these flags,
|
||||
# but they can also be set in shark/parser.py
|
||||
def get_model_specific_args():
|
||||
ms_args = []
|
||||
if shark_args.enable_conv_transform == True:
|
||||
ms_args += [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))"
|
||||
]
|
||||
if shark_args.enable_img2col_transform == True:
|
||||
ms_args += [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col))"
|
||||
]
|
||||
if shark_args.use_winograd == True:
|
||||
ms_args += [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
return ms_args
|
||||
|
||||
|
||||
def create_dispatch_dirs(bench_dir, device):
|
||||
protected_files = ["ordered-dispatches.txt"]
|
||||
bench_dir_path = bench_dir.split("/")
|
||||
bench_dir_path[-1] = "temp_" + bench_dir_path[-1]
|
||||
tmp_bench_dir = "/".join(bench_dir_path)
|
||||
for f_ in os.listdir(bench_dir):
|
||||
if os.path.isfile(f"{bench_dir}/{f_}") and f_ not in protected_files:
|
||||
dir_name = re.sub("\.\S*$", "", f_)
|
||||
if os.path.exists(f"{bench_dir}/{dir_name}"):
|
||||
os.system(f"rm -rf {bench_dir}/{dir_name}")
|
||||
os.system(f"mkdir {bench_dir}/{dir_name}")
|
||||
os.system(f"mv {bench_dir}/{f_} {bench_dir}/{dir_name}/{f_}")
|
||||
for f_ in os.listdir(tmp_bench_dir):
|
||||
if os.path.isfile(f"{tmp_bench_dir}/{f_}"):
|
||||
dir_name = ""
|
||||
for d_ in os.listdir(bench_dir):
|
||||
if re.search(f"{d_}(?=\D)", f_):
|
||||
dir_name = d_
|
||||
if dir_name != "":
|
||||
os.system(
|
||||
f"mv {tmp_bench_dir}/{f_} {bench_dir}/{dir_name}/{dir_name}_benchmark.mlir"
|
||||
)
|
||||
|
||||
|
||||
def dump_isas(bench_dir):
|
||||
for d_ in os.listdir(bench_dir):
|
||||
if os.path.isdir(f"{bench_dir}/{d_}"):
|
||||
for f_ in os.listdir(f"{bench_dir}/{d_}"):
|
||||
if f_.endswith(".spv"):
|
||||
os.system(
|
||||
f"amdllpc -gfxip 11.0 {bench_dir}/{d_}/{f_} -v > \
|
||||
{bench_dir}/{d_}/isa.txt"
|
||||
)
|
||||
|
||||
|
||||
def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
|
||||
benchmark_runtimes = {}
|
||||
dispatch_list = []
|
||||
all_dispatches = False
|
||||
|
||||
if dispatch_benchmarks.lower().strip() == "all":
|
||||
all_dispatches = True
|
||||
else:
|
||||
try:
|
||||
dispatch_list = [
|
||||
int(dispatch_index)
|
||||
for dispatch_index in dispatch_benchmarks.split(" ")
|
||||
]
|
||||
except:
|
||||
print("ERROR: Invalid dispatch benchmarks")
|
||||
return None
|
||||
for d_ in os.listdir(bench_dir):
|
||||
if os.path.isdir(f"{bench_dir}/{d_}"):
|
||||
in_dispatches = False
|
||||
for dispatch in dispatch_list:
|
||||
if str(dispatch) in d_:
|
||||
in_dispatches = True
|
||||
if all_dispatches or in_dispatches:
|
||||
for f_ in os.listdir(f"{bench_dir}/{d_}"):
|
||||
if "benchmark.mlir" in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
module = dispatch_file.read()
|
||||
dispatch_file.close()
|
||||
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module, target_backends=[iree_target_map(device)]
|
||||
)
|
||||
|
||||
vmfb_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_benchmark.vmfb", "wb"
|
||||
)
|
||||
vmfb_file.write(flatbuffer_blob)
|
||||
vmfb_file.close()
|
||||
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
config.vm_instance,
|
||||
flatbuffer_blob,
|
||||
warn_if_copy=False,
|
||||
)
|
||||
|
||||
benchmark_cl = build_benchmark_args_non_tensor_input(
|
||||
input_file=f"{bench_dir}/{d_}/{d_}_benchmark.vmfb",
|
||||
device=device,
|
||||
inputs=(0,),
|
||||
mlir_dialect="linalg",
|
||||
function_name="",
|
||||
)
|
||||
|
||||
benchmark_bash = open(
|
||||
f"{bench_dir}/{d_}/{d_}_benchmark.sh", "w+"
|
||||
)
|
||||
benchmark_bash.write("#!/bin/bash\n")
|
||||
benchmark_bash.write(" ".join(benchmark_cl))
|
||||
benchmark_bash.close()
|
||||
|
||||
iter_per_second, _, _ = run_benchmark_module(
|
||||
benchmark_cl
|
||||
)
|
||||
|
||||
benchmark_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_data.txt", "w+"
|
||||
)
|
||||
benchmark_file.write(f"DISPATCH: {d_}\n")
|
||||
benchmark_file.write(str(iter_per_second) + "\n")
|
||||
benchmark_file.write(
|
||||
"SHARK BENCHMARK RESULT: "
|
||||
+ str(1 / (iter_per_second * 0.001))
|
||||
+ "\n"
|
||||
)
|
||||
benchmark_file.close()
|
||||
|
||||
benchmark_runtimes[d_] = 1 / (iter_per_second * 0.001)
|
||||
|
||||
elif ".mlir" in f_ and "benchmark" not in f_:
|
||||
dispatch_file = open(f"{bench_dir}/{d_}/{f_}", "r")
|
||||
module = dispatch_file.read()
|
||||
dispatch_file.close()
|
||||
|
||||
module = re.sub(
|
||||
"hal.executable private",
|
||||
"hal.executable public",
|
||||
module,
|
||||
)
|
||||
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=["--compile-mode=hal-executable"],
|
||||
)
|
||||
|
||||
spirv_file = open(
|
||||
f"{bench_dir}/{d_}/{d_}_spirv.vmfb", "wb"
|
||||
)
|
||||
spirv_file.write(flatbuffer_blob)
|
||||
spirv_file.close()
|
||||
|
||||
ordered_dispatches = [
|
||||
(k, v)
|
||||
for k, v in sorted(
|
||||
benchmark_runtimes.items(), key=lambda item: item[1]
|
||||
)
|
||||
][::-1]
|
||||
f_ = open(f"{bench_dir}/ordered-dispatches.txt", "w+")
|
||||
for dispatch in ordered_dispatches:
|
||||
f_.write(f"{dispatch[0]}: {dispatch[1]}ms\n")
|
||||
f_.close()
|
||||
|
||||
|
||||
def compile_module_to_flatbuffer(
|
||||
module,
|
||||
device,
|
||||
frontend,
|
||||
model_config_path,
|
||||
extra_args,
|
||||
model_name="None",
|
||||
debug=False,
|
||||
compile_str=False,
|
||||
write_to=None,
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = "auto"
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args(debug=debug)
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
args += shark_args.additional_compile_args
|
||||
|
||||
if frontend in ["tensorflow", "tf"]:
|
||||
input_type = "auto"
|
||||
elif frontend in ["stablehlo", "tosa"]:
|
||||
input_type = frontend
|
||||
elif frontend in ["tflite", "tflite-tosa"]:
|
||||
input_type = "tosa"
|
||||
elif frontend in ["tm_tensor"]:
|
||||
input_type = ireec.InputType.TM_TENSOR
|
||||
elif frontend in ["torch", "pytorch"]:
|
||||
input_type = "torch"
|
||||
|
||||
if compile_str:
|
||||
flatbuffer_blob = ireec.compile_str(
|
||||
module,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
input_type=input_type,
|
||||
)
|
||||
else:
|
||||
assert os.path.isfile(module)
|
||||
flatbuffer_blob = ireec.compile_file(
|
||||
str(module),
|
||||
input_type=input_type,
|
||||
target_backends=[iree_target_map(device)],
|
||||
extra_args=args,
|
||||
)
|
||||
|
||||
if write_to is not None:
|
||||
with open(write_to, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
return None
|
||||
|
||||
return flatbuffer_blob
|
||||
|
||||
|
||||
def get_iree_module(
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=None,
|
||||
rt_flags: list = [],
|
||||
external_weight_file=None,
|
||||
):
|
||||
if external_weight_file is not None:
|
||||
index = ireert.ParameterIndex()
|
||||
index.load(external_weight_file)
|
||||
# Returns the compiled module and the configs.
|
||||
for flag in rt_flags:
|
||||
ireert.flags.parse_flag(flag)
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
hal_device_id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
haldevice = haldriver.create_device(
|
||||
hal_device_id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = hal_device_id
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
config.vm_instance, flatbuffer_blob, warn_if_copy=False
|
||||
)
|
||||
modules = []
|
||||
if external_weight_file is not None:
|
||||
modules.append(index.create_provider(scope="model"))
|
||||
ctx = ireert.SystemContext(vm_modules=modules, config=config)
|
||||
ctx.add_vm_module(vm_module)
|
||||
ModuleCompiled = getattr(ctx.modules, vm_module.name)
|
||||
return ModuleCompiled, config
|
||||
|
||||
|
||||
def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path,
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
rt_flags: list = [],
|
||||
external_weight_file: str = None,
|
||||
):
|
||||
print(f"Loading module {flatbuffer_blob_or_path}...")
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
rt_flags.append(flag)
|
||||
for flag in rt_flags:
|
||||
print(flag)
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
if "rocm" in device:
|
||||
device = "rocm"
|
||||
with DetailLogger(timeout=2.5) as dl:
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
dl.log(f"Mapping device id: {device_idx}")
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
dl.log(f"ireert.get_driver()")
|
||||
|
||||
hal_device_id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
haldevice = haldriver.create_device(
|
||||
hal_device_id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
dl.log(f"ireert.create_device()")
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = hal_device_id
|
||||
dl.log(f"ireert.Config()")
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
dl.log("get_iree_runtime_config")
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
# (This would arise if we're invoking `load_module` from a SharkInference obj)
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a SharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
isinstance(flatbuffer_blob_or_path, str)
|
||||
and ".vmfb" in flatbuffer_blob_or_path
|
||||
):
|
||||
vmfb_file_path = flatbuffer_blob_or_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(
|
||||
config.vm_instance, flatbuffer_blob_or_path
|
||||
)
|
||||
vm_modules = []
|
||||
if external_weight_file is not None:
|
||||
index = ireert.ParameterIndex()
|
||||
index.load(external_weight_file)
|
||||
param_module = ireert.create_io_parameters_module(
|
||||
config.vm_instance, index.create_provider(scope="model")
|
||||
)
|
||||
vm_modules.append(param_module)
|
||||
vm_modules.append(mmaped_vmfb)
|
||||
vm_modules.append(
|
||||
ireert.create_hal_module(config.vm_instance, config.device)
|
||||
)
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
if "vulkan" in device:
|
||||
# Vulkan pipeline creation consumes significant amount of time.
|
||||
print(
|
||||
"\tCompiling Vulkan shaders. This may take a few minutes."
|
||||
)
|
||||
ctx = ireert.SystemContext(config=config, vm_modules=vm_modules)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
for flag in shark_args.additional_runtime_args:
|
||||
ireert.flags.parse_flags(flag)
|
||||
dl.log(f"module initialized")
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tf.write(flatbuffer_blob_or_path)
|
||||
tf.flush()
|
||||
vmfb_file_path = tf.name
|
||||
temp_file_to_unlink = vmfb_file_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
|
||||
dl.log(f"mmap temp {vmfb_file_path}")
|
||||
return mmaped_vmfb, config, temp_file_to_unlink
|
||||
|
||||
|
||||
def get_iree_compiled_module(
|
||||
module,
|
||||
device: str,
|
||||
frontend: str = "torch",
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
rt_flags: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
external_weight_file: str = None,
|
||||
write_to: bool = None,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module=module,
|
||||
device=device,
|
||||
frontend=frontend,
|
||||
model_config_path=model_config_path,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=compile_str,
|
||||
write_to=write_to,
|
||||
)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
# Got to find a cleaner way to unlink/delete the temporary file since
|
||||
# we're setting delete=False when creating NamedTemporaryFile. That's why
|
||||
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
|
||||
if mmap:
|
||||
if write_to is not None:
|
||||
flatbuffer_blob = write_to
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx,
|
||||
rt_flags,
|
||||
external_weight_file=external_weight_file,
|
||||
)
|
||||
else:
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
external_weight_file=external_weight_file,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
"config": config,
|
||||
"temp_file_to_unlink": temp_file_to_unlink,
|
||||
}
|
||||
return ret_params
|
||||
|
||||
|
||||
def load_flatbuffer(
|
||||
flatbuffer_path: str,
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
temp_file_to_unlink = None
|
||||
if mmap:
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_path, device, device_idx, rt_flags
|
||||
)
|
||||
else:
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
"config": config,
|
||||
"temp_file_to_unlink": temp_file_to_unlink,
|
||||
}
|
||||
return ret_params
|
||||
|
||||
|
||||
def export_iree_module_to_vmfb(
|
||||
module,
|
||||
device: str,
|
||||
directory: str,
|
||||
mlir_dialect: str = "linalg",
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
debug: bool = False,
|
||||
compile_str: bool = False,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module=module,
|
||||
device=device,
|
||||
frontend=mlir_dialect,
|
||||
model_config_path=model_config_path,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
compile_str=compile_str,
|
||||
)
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
device if "://" not in device else "-".join(device.split("://"))
|
||||
)
|
||||
module_name = f"{mlir_dialect}_{device_name}"
|
||||
filename = os.path.join(directory, module_name + ".vmfb")
|
||||
with open(filename, "wb") as f:
|
||||
f.write(flatbuffer_blob)
|
||||
print(f"Saved vmfb in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
def export_module_to_mlir_file(module, frontend, directory: str):
|
||||
# TODO: write proper documentation.
|
||||
mlir_str = module
|
||||
if frontend in ["tensorflow", "tf", "mhlo", "stablehlo", "tflite"]:
|
||||
mlir_str = module.decode("utf-8")
|
||||
elif frontend in ["pytorch", "torch"]:
|
||||
mlir_str = module.operation.get_asm()
|
||||
filename = os.path.join(directory, "model.mlir")
|
||||
with open(filename, "w") as f:
|
||||
f.write(mlir_str)
|
||||
print(f"Saved mlir in {filename}.")
|
||||
return filename
|
||||
|
||||
|
||||
def get_results(
|
||||
compiled_vm,
|
||||
function_name,
|
||||
input,
|
||||
config,
|
||||
frontend="torch",
|
||||
send_to_host=True,
|
||||
debug_timeout: float = 5.0,
|
||||
device: str = None,
|
||||
):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
with DetailLogger(debug_timeout) as dl:
|
||||
device_inputs = []
|
||||
if device == "rocm" and hasattr(config, "id"):
|
||||
haldriver = ireert.get_driver("rocm")
|
||||
haldevice = haldriver.create_device(
|
||||
config.id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
for input_array in input:
|
||||
dl.log(f"Load to device: {input_array.shape}")
|
||||
device_inputs.append(
|
||||
ireert.asdevicearray(config.device, input_array)
|
||||
)
|
||||
dl.log(f"Invoke function: {function_name}")
|
||||
result = compiled_vm[function_name](*device_inputs)
|
||||
dl.log(f"Invoke complete")
|
||||
result_tensors = []
|
||||
if isinstance(result, tuple):
|
||||
if send_to_host:
|
||||
for val in result:
|
||||
dl.log(f"Result to host: {val.shape}")
|
||||
result_tensors.append(np.asarray(val, val.dtype))
|
||||
else:
|
||||
for val in result:
|
||||
result_tensors.append(val)
|
||||
return result_tensors
|
||||
elif isinstance(result, dict):
|
||||
data = list(result.items())
|
||||
if send_to_host:
|
||||
res = np.array(data, dtype=object)
|
||||
return np.copy(res)
|
||||
return data
|
||||
else:
|
||||
if send_to_host and result is not None:
|
||||
dl.log("Result to host")
|
||||
return result.to_host()
|
||||
return result
|
||||
dl.log("Execution complete")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_iree_runtime_config(device):
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
if "metal" in device and shark_args.device_allocator == "caching":
|
||||
print(
|
||||
"[WARNING] metal devices can not have a `caching` allocator."
|
||||
"\nUsing default allocator `None`"
|
||||
)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
# metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream.
|
||||
allocators=shark_args.device_allocator
|
||||
if "metal" not in device
|
||||
else None,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
return config
|
||||
65
shark/iree_utils/cpu_utils.py
Normal file
65
shark/iree_utils/cpu_utils.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_cpu related functionalities go here.
|
||||
|
||||
import functools
|
||||
import subprocess
|
||||
import platform
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
import multiprocessing
|
||||
|
||||
try:
|
||||
cpu_count = multiprocessing.cpu_count()
|
||||
return cpu_count
|
||||
except NotImplementedError:
|
||||
return None
|
||||
|
||||
|
||||
# Get the default cpu args.
|
||||
@functools.cache
|
||||
def get_iree_cpu_args():
|
||||
uname = platform.uname()
|
||||
os_name, proc_name = uname.system, uname.machine
|
||||
|
||||
if os_name == "Darwin":
|
||||
kernel_version = uname.release
|
||||
target_triple = f"{proc_name}-apple-darwin{kernel_version}"
|
||||
elif os_name == "Linux":
|
||||
target_triple = f"{proc_name}-linux-gnu"
|
||||
elif os_name == "Windows":
|
||||
target_triple = "x86_64-pc-windows-msvc"
|
||||
else:
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
return [
|
||||
f"--iree-llvmcpu-target-triple={target_triple}",
|
||||
]
|
||||
|
||||
|
||||
# Get iree runtime flags for cpu
|
||||
@functools.cache
|
||||
def get_iree_cpu_rt_args():
|
||||
default = get_cpu_count()
|
||||
default = default if default <= 8 else default - 2
|
||||
cpu_count = (
|
||||
default
|
||||
if shark_args.task_topology_max_group_count is None
|
||||
else shark_args.task_topology_max_group_count
|
||||
)
|
||||
return [f"--task_topology_max_group_count={cpu_count}"]
|
||||
209
shark/iree_utils/gpu_utils.py
Normal file
209
shark/iree_utils/gpu_utils.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_gpu related functionalities go here.
|
||||
|
||||
import functools
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
from shark.parser import shark_args
|
||||
from shark.iree_utils._common import run_cmd
|
||||
|
||||
# TODO: refactor to rocm and cuda utils
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@functools.cache
|
||||
def get_iree_gpu_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
ireert.flags.parse_flags("--cuda_allow_inline_execution")
|
||||
# TODO: Give the user_interface to pass the sm_arch.
|
||||
sm_arch = get_cuda_sm_cc()
|
||||
if (
|
||||
sm_arch
|
||||
in ["sm_70", "sm_72", "sm_75", "sm_80", "sm_84", "sm_86", "sm_89"]
|
||||
) and (shark_args.enable_tf32 == True):
|
||||
return [
|
||||
f"--iree-hal-cuda-llvm-target-arch={sm_arch}",
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def check_rocm_device_arch_in_args(extra_args):
|
||||
# Check if the target arch flag for rocm device present in extra_args
|
||||
for flag in extra_args:
|
||||
if "iree-rocm-target-chip" in flag:
|
||||
flag_arch = flag.split("=")[1]
|
||||
return flag_arch
|
||||
return None
|
||||
|
||||
|
||||
def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False):
|
||||
# ROCM Device Arch selection:
|
||||
# 1 : User given device arch using `--iree-rocm-target-chip` flag
|
||||
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index <device_num>
|
||||
# 3 : default arch : gfx1100
|
||||
|
||||
arch_in_flag = check_rocm_device_arch_in_args(extra_args)
|
||||
if arch_in_flag is not None:
|
||||
print(
|
||||
f"User Specified rocm target device arch from flag : {arch_in_flag} will be used"
|
||||
)
|
||||
return arch_in_flag
|
||||
|
||||
arch_in_device_dump = None
|
||||
|
||||
# get rocm arch from iree dump devices
|
||||
def get_devices_info_from_dump(dump, driver):
|
||||
from os import linesep
|
||||
|
||||
if driver == "hip":
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: "AMD" in s,
|
||||
dump.split(linesep),
|
||||
)
|
||||
)
|
||||
else:
|
||||
dump_clean = list(
|
||||
filter(
|
||||
lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s,
|
||||
dump.split(linesep),
|
||||
)
|
||||
)
|
||||
arch_pairs = [
|
||||
(
|
||||
dump_clean[i].split("=")[1].strip(),
|
||||
dump_clean[i + 1].split(":")[1].strip(),
|
||||
)
|
||||
for i in range(0, len(dump_clean), 2)
|
||||
]
|
||||
return arch_pairs
|
||||
|
||||
dump_device_info = None
|
||||
driver = "hip" if hip_driver else "rocm"
|
||||
try:
|
||||
dump_device_info = run_cmd(
|
||||
"iree-run-module --dump_devices=" + driver, raise_err=True
|
||||
)
|
||||
except Exception as e:
|
||||
print("could not execute `iree-run-module --dump_devices=" + driver + "`")
|
||||
|
||||
if dump_device_info is not None:
|
||||
device_num = 0 if device_num is None else device_num
|
||||
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver)
|
||||
if len(device_arch_pairs) > device_num: # can find arch in the list
|
||||
arch_in_device_dump = device_arch_pairs[device_num][1]
|
||||
|
||||
if arch_in_device_dump is not None:
|
||||
print(f"Found ROCm device arch : {arch_in_device_dump}")
|
||||
return arch_in_device_dump
|
||||
|
||||
default_rocm_arch = "gfx1100"
|
||||
print(
|
||||
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
|
||||
"\n or from `iree-run-module --dump_devices` command."
|
||||
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
|
||||
)
|
||||
return default_rocm_arch
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False):
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
rocm_flags = []
|
||||
if check_rocm_device_arch_in_args(extra_args) is None:
|
||||
rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver)
|
||||
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
|
||||
|
||||
return rocm_flags
|
||||
|
||||
# Some constants taken from cuda.h
|
||||
CUDA_SUCCESS = 0
|
||||
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16
|
||||
CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39
|
||||
CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
|
||||
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_cuda_sm_cc():
|
||||
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
|
||||
for libname in libnames:
|
||||
try:
|
||||
cuda = ctypes.CDLL(libname)
|
||||
except OSError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise OSError("could not load any of: " + " ".join(libnames))
|
||||
|
||||
nGpus = ctypes.c_int()
|
||||
name = b" " * 100
|
||||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
|
||||
result = ctypes.c_int()
|
||||
device = ctypes.c_int()
|
||||
context = ctypes.c_void_p()
|
||||
error_str = ctypes.c_char_p()
|
||||
|
||||
result = cuda.cuInit(0)
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print(
|
||||
"cuInit failed with error code %d: %s"
|
||||
% (result, error_str.value.decode())
|
||||
)
|
||||
return 1
|
||||
result = cuda.cuDeviceGetCount(ctypes.byref(nGpus))
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print(
|
||||
"cuDeviceGetCount failed with error code %d: %s"
|
||||
% (result, error_str.value.decode())
|
||||
)
|
||||
return 1
|
||||
print("Found %d device(s)." % nGpus.value)
|
||||
for i in range(nGpus.value):
|
||||
result = cuda.cuDeviceGet(ctypes.byref(device), i)
|
||||
if result != CUDA_SUCCESS:
|
||||
cuda.cuGetErrorString(result, ctypes.byref(error_str))
|
||||
print(
|
||||
"cuDeviceGet failed with error code %d: %s"
|
||||
% (result, error_str.value.decode())
|
||||
)
|
||||
return 1
|
||||
print("Device: %d" % i)
|
||||
if (
|
||||
cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device)
|
||||
== CUDA_SUCCESS
|
||||
):
|
||||
print(" Name: %s" % (name.split(b"\0", 1)[0].decode()))
|
||||
if (
|
||||
cuda.cuDeviceComputeCapability(
|
||||
ctypes.byref(cc_major), ctypes.byref(cc_minor), device
|
||||
)
|
||||
== CUDA_SUCCESS
|
||||
):
|
||||
print(
|
||||
" Compute Capability: %d.%d"
|
||||
% (cc_major.value, cc_minor.value)
|
||||
)
|
||||
sm = f"sm_{cc_major.value}{cc_minor.value}"
|
||||
return sm
|
||||
102
shark/iree_utils/metal_utils.py
Normal file
102
shark/iree_utils/metal_utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
import functools
|
||||
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_metal_device_name(device_num=0):
|
||||
iree_device_dump = run_cmd("iree-run-module --dump_devices")
|
||||
iree_device_dump = iree_device_dump[0].split("\n\n")
|
||||
metal_device_list = [
|
||||
s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s
|
||||
]
|
||||
if len(metal_device_list) == 0:
|
||||
raise ValueError("No device name found in device dump!")
|
||||
if len(metal_device_list) > 1:
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(metal_device_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing device: {metal_device_list[device_num]}")
|
||||
return metal_device_list[device_num]
|
||||
|
||||
|
||||
def get_os_name():
|
||||
if platform.startswith("linux"):
|
||||
return "linux"
|
||||
elif platform == "darwin":
|
||||
return "macos"
|
||||
elif platform == "win32":
|
||||
return "windows"
|
||||
else:
|
||||
print("Cannot detect OS type, defaulting to linux.")
|
||||
return "linux"
|
||||
|
||||
|
||||
def get_metal_target_triple(device_name):
|
||||
"""This method provides a target triple str for specified vulkan device.
|
||||
|
||||
Args:
|
||||
device_name (str): name of the hardware device to be used with vulkan
|
||||
|
||||
Returns:
|
||||
str or None: target triple or None if no match found for given name
|
||||
"""
|
||||
return "macos"
|
||||
|
||||
|
||||
def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
|
||||
for flag in extra_args:
|
||||
if "-iree-metal-target-platform=" in flag:
|
||||
print(f"Using target triple {flag.split('=')[1]}")
|
||||
return None
|
||||
|
||||
if device_name == "" or device_name == [] or device_name is None:
|
||||
metal_device = get_metal_device_name(device_num=device_num)
|
||||
else:
|
||||
metal_device = device_name
|
||||
triple = get_metal_target_triple(metal_device)
|
||||
if triple is not None:
|
||||
print(
|
||||
f"Found metal device {metal_device}. Using metal target platform {triple}"
|
||||
)
|
||||
return f"-iree-metal-target-platform={triple}"
|
||||
print(
|
||||
"""Optimized kernel for your target device is not added yet.
|
||||
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
|
||||
or pull up an issue."""
|
||||
)
|
||||
print(f"Target : {metal_device}")
|
||||
return None
|
||||
|
||||
|
||||
def get_iree_metal_args(device_num=0, extra_args=[]):
|
||||
# Add any metal spefic compilation flags here
|
||||
res_metal_flag = []
|
||||
if len(extra_args) > 0:
|
||||
res_metal_flag.extend(extra_args)
|
||||
return res_metal_flag
|
||||
|
||||
|
||||
def set_iree_metal_runtime_flags(flags):
|
||||
for flag in flags:
|
||||
ireert.flags.parse_flags(flag)
|
||||
return
|
||||
76
shark/iree_utils/trace.py
Normal file
76
shark/iree_utils/trace.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
def _enable_detail_trace() -> bool:
|
||||
return os.getenv("SHARK_DETAIL_TRACE", "0") == "1"
|
||||
|
||||
|
||||
class DetailLogger:
|
||||
"""Context manager which can accumulate detailed log messages.
|
||||
|
||||
Detailed log is only emitted if the operation takes a long time
|
||||
or errors.
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: float):
|
||||
self._timeout = timeout
|
||||
self._messages: List[Tuple[float, str]] = []
|
||||
self._start_time = time.time()
|
||||
self._active = not _enable_detail_trace()
|
||||
self._lock = threading.RLock()
|
||||
self._cond = threading.Condition(self._lock)
|
||||
self._thread = None
|
||||
|
||||
def __enter__(self):
|
||||
self._thread = threading.Thread(target=self._run)
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
with self._lock:
|
||||
self._active = False
|
||||
self._cond.notify()
|
||||
if traceback:
|
||||
self.dump_on_error(f"exception")
|
||||
|
||||
def _run(self):
|
||||
with self._lock:
|
||||
timed_out = not self._cond.wait(self._timeout)
|
||||
if timed_out:
|
||||
self.dump_on_error(f"took longer than {self._timeout}s")
|
||||
|
||||
def log(self, msg):
|
||||
with self._lock:
|
||||
timestamp = time.time()
|
||||
if self._active:
|
||||
self._messages.append((timestamp, msg))
|
||||
else:
|
||||
print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}")
|
||||
|
||||
def dump_on_error(self, summary: str):
|
||||
with self._lock:
|
||||
if self._active:
|
||||
print(f"::: Detailed report ({summary}):")
|
||||
for timestamp, msg in self._messages:
|
||||
print(
|
||||
f" +{(timestamp - self._start_time) * 1000}ms: {msg}"
|
||||
)
|
||||
self._active = False
|
||||
538
shark/iree_utils/vulkan_target_env_utils.py
Normal file
538
shark/iree_utils/vulkan_target_env_utils.py
Normal file
@@ -0,0 +1,538 @@
|
||||
# Copyright 2020 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
import functools
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_target_env(vulkan_target_triple):
|
||||
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
|
||||
triple = (arch, product, os)
|
||||
# get version
|
||||
version = get_version(triple=triple)
|
||||
# TODO get revision
|
||||
revision = 120
|
||||
|
||||
# extensions
|
||||
extensions = get_extensions(triple)
|
||||
# get vendor
|
||||
vendor = get_vendor(triple)
|
||||
# get device type
|
||||
device_type = get_device_type(triple)
|
||||
# get capabilities
|
||||
capabilities = get_vulkan_target_capabilities(triple)
|
||||
target_env = f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>"
|
||||
return target_env
|
||||
|
||||
|
||||
def get_vulkan_target_env_flag(vulkan_target_triple):
|
||||
target_env = get_vulkan_target_env(vulkan_target_triple)
|
||||
target_env_flag = f"--iree-vulkan-target-env={target_env}"
|
||||
return target_env_flag
|
||||
|
||||
|
||||
def get_version(triple):
|
||||
arch, product, os = triple
|
||||
if os in ["android30", "android31"]:
|
||||
return "v1.1"
|
||||
if product in ["android30", "android31"]:
|
||||
return "v1.1"
|
||||
if arch in ["unknown"]:
|
||||
return "v1.1"
|
||||
return "v1.3"
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_extensions(triple):
|
||||
def make_ext_list(ext_list):
|
||||
res = ", ".join(ext_list)
|
||||
return f"[{res}]"
|
||||
|
||||
arch, product, os = triple
|
||||
if arch == "m1":
|
||||
ext = [
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "valhall":
|
||||
ext = [
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "adreno":
|
||||
ext = [
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
if os == "android31":
|
||||
ext.append("SPV_KHR_8bit_storage")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if get_vendor(triple) == "SwiftShader":
|
||||
ext = ["SPV_KHR_storage_buffer_storage_class"]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
if arch == "unknown":
|
||||
ext = [
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
]
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
ext = [
|
||||
"SPV_KHR_16bit_storage",
|
||||
"SPV_KHR_8bit_storage",
|
||||
"SPV_KHR_shader_float16_int8",
|
||||
"SPV_KHR_spirv_1_4",
|
||||
"SPV_KHR_storage_buffer_storage_class",
|
||||
"SPV_KHR_variable_pointers",
|
||||
"VK_EXT_subgroup_size_control",
|
||||
]
|
||||
|
||||
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
|
||||
ext.append("SPV_KHR_cooperative_matrix")
|
||||
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
|
||||
ext.append("SPV_KHR_shader_integer_dot_product")
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vendor(triple):
|
||||
arch, product, os = triple
|
||||
if arch == "unknown":
|
||||
return "Unknown"
|
||||
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn4", "rgcn5"]:
|
||||
return "AMD"
|
||||
if arch == "valhall":
|
||||
return "ARM"
|
||||
if arch == "m1":
|
||||
return "Apple"
|
||||
if arch in ["arc", "UHD"]:
|
||||
return "Intel"
|
||||
if arch in ["turing", "ampere", "pascal"]:
|
||||
return "NVIDIA"
|
||||
if arch == "adreno":
|
||||
return "Qualcomm"
|
||||
if arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
return "SwiftShader"
|
||||
return "Unknown"
|
||||
print(f"Vendor for target triple - {triple} not found. Using unknown")
|
||||
return "Unknown"
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_device_type(triple):
|
||||
arch, product, _ = triple
|
||||
if arch == "unknown":
|
||||
return "Unknown"
|
||||
if arch == "cpu":
|
||||
return "CPU"
|
||||
if arch in ["turing", "ampere", "arc", "pascal"]:
|
||||
return "DiscreteGPU"
|
||||
if arch in ["rdna1", "rdna2", "rdna3", "rgcn3", "rgcn5"]:
|
||||
if product == "ivega10":
|
||||
return "IntegratedGPU"
|
||||
return "DiscreteGPU"
|
||||
if arch in ["m1", "valhall", "adreno"]:
|
||||
return "IntegratedGPU"
|
||||
print(f"Device type for target triple - {triple} not found. Using unknown")
|
||||
return "Unknown"
|
||||
|
||||
|
||||
# get all the capabilities for the device
|
||||
# TODO: make a dataclass for capabilites and init using vulkaninfo
|
||||
@functools.cache
|
||||
def get_vulkan_target_capabilities(triple):
|
||||
def get_subgroup_val(l):
|
||||
return int(sum([subgroup_feature[sgf] for sgf in l]))
|
||||
|
||||
cap = OrderedDict()
|
||||
arch, product, os = triple
|
||||
subgroup_feature = {
|
||||
"Basic": 1,
|
||||
"Vote": 2,
|
||||
"Arithmetic": 4,
|
||||
"Ballot": 8,
|
||||
"Shuffle": 16,
|
||||
"ShuffleRelative": 32,
|
||||
"Clustered": 64,
|
||||
"Quad": 128,
|
||||
"PartitionedNV": 256,
|
||||
}
|
||||
cap["max_compute_shared_memory_size"] = 16384
|
||||
cap["max_compute_workgroup_invocations"] = 128
|
||||
cap["max_compute_workgroup_size"] = [128, 128, 64]
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = ["Basic"]
|
||||
cap["min_subgroup_size"] = None
|
||||
cap["max_subgroup_size"] = None
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = False
|
||||
cap["shaderInt8"] = False
|
||||
cap["shaderInt16"] = False
|
||||
cap["shaderInt64"] = False
|
||||
cap["storageBuffer16BitAccess"] = False
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = False
|
||||
cap["storageBuffer8BitAccess"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = False
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
cap["coopmatCases"] = None
|
||||
|
||||
if arch in ["rdna1", "rdna2", "rdna3"]:
|
||||
cap["max_compute_shared_memory_size"] = 65536
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroup_size"] = 64
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
if arch == "rdna3":
|
||||
# TODO: Get scope value
|
||||
cap["coopmatCases"] = [
|
||||
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>",
|
||||
"m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>"
|
||||
]
|
||||
|
||||
if product == "rx5700xt":
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["storagePushConstant8"] = False
|
||||
|
||||
elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
|
||||
cap["max_compute_shared_memory_size"] = 65536
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
cap["min_subgroup_size"] = 64
|
||||
cap["max_subgroup_size"] = 64
|
||||
|
||||
if arch == "rgcn5":
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storagePushConstant16"] = False
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = False
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "m1":
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = False
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "valhall":
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 512
|
||||
cap["max_compute_workgroup_size"] = [512, 512, 512]
|
||||
|
||||
cap["subgroup_size"] = 16
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
if os == "android31":
|
||||
cap["subgroupFeatures"].append("Shuffle")
|
||||
cap["subgroupFeatures"].append("ShuffleRelative")
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "arc":
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = False
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = False
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "cpu":
|
||||
if product == "swiftshader":
|
||||
cap["max_compute_shared_memory_size"] = 16384
|
||||
cap["subgroup_size"] = 4
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
]
|
||||
|
||||
elif arch in ["pascal"]:
|
||||
cap["max_compute_shared_memory_size"] = 49152
|
||||
cap["max_compute_workgroup_invocations"] = 1536
|
||||
cap["max_compute_workgroup_size"] = [1536, 1024, 64]
|
||||
|
||||
cap["subgroup_size"] = 32
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = False
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch in ["ampere", "turing"]:
|
||||
cap["max_compute_shared_memory_size"] = 49152
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 1024]
|
||||
|
||||
cap["subgroup_size"] = 32
|
||||
cap["min_subgroup_size"] = 32
|
||||
cap["max_subgroup_size"] = 32
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Clustered",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderFloat64"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
cap["shaderInt64"] = True
|
||||
cap["shaderIntegerDotProduct"] = True
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
cap["storagePushConstant16"] = True
|
||||
cap["uniformAndStorageBuffer16BitAccess"] = True
|
||||
cap["storageBuffer8BitAccess"] = True
|
||||
cap["storagePushConstant8"] = True
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
cap["coopmatCases"] = [
|
||||
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
|
||||
]
|
||||
|
||||
elif arch == "adreno":
|
||||
cap["max_compute_shared_memory_size"] = 32768
|
||||
cap["max_compute_workgroup_invocations"] = 1024
|
||||
cap["max_compute_workgroup_size"] = [1024, 1024, 64]
|
||||
|
||||
cap["subgroup_size"] = 64
|
||||
cap["subgroupFeatures"] = [
|
||||
"Basic",
|
||||
"Vote",
|
||||
"Arithmetic",
|
||||
"Ballot",
|
||||
"Shuffle",
|
||||
"ShuffleRelative",
|
||||
"Quad",
|
||||
]
|
||||
|
||||
cap["shaderFloat16"] = True
|
||||
cap["shaderInt8"] = True
|
||||
cap["shaderInt16"] = True
|
||||
|
||||
cap["storageBuffer16BitAccess"] = True
|
||||
if os == "android31":
|
||||
cap["uniformAndStorageBuffer8BitAccess"] = True
|
||||
|
||||
cap["variablePointers"] = True
|
||||
cap["variablePointersStorageBuffer"] = True
|
||||
|
||||
elif arch == "unknown":
|
||||
cap["subgroup_size"] = 64
|
||||
cap["variablePointers"] = False
|
||||
cap["variablePointersStorageBuffer"] = False
|
||||
else:
|
||||
print(
|
||||
f"Architecture {arch} not matched. Using default vulkan target device capability"
|
||||
)
|
||||
|
||||
def get_comma_sep_str(ele_list):
|
||||
l = ""
|
||||
for ele in ele_list:
|
||||
l += f"{ele}, "
|
||||
l = f"[{l[:-2]}]"
|
||||
return l
|
||||
|
||||
res = ""
|
||||
for k, v in cap.items():
|
||||
if v is None or v == False:
|
||||
continue
|
||||
if isinstance(v, bool):
|
||||
res += f"{k} = {'unit' if v == True else None}, "
|
||||
elif isinstance(v, list):
|
||||
if k == "subgroupFeatures":
|
||||
res += f"subgroup_features = {get_subgroup_val(v)}: i32, "
|
||||
elif k == "max_compute_workgroup_size":
|
||||
res += f"max_compute_workgroup_size = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
|
||||
elif k == "coopmatCases":
|
||||
cmc = ""
|
||||
for case in v:
|
||||
cmc += f"#spirv.coop_matrix_props_khr<{case}>, "
|
||||
res += f"cooperative_matrix_properties_khr = [{cmc[:-2]}], "
|
||||
else:
|
||||
res += f"{k} = {get_comma_sep_str(v)}, "
|
||||
else:
|
||||
res += f"{k} = {v}, "
|
||||
res = res[:-2]
|
||||
return res
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user