mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add --additional_runtime_args option and use in OPT example. (#1855)
* Add --additional_runtime_args option and use in OPT example. Fix the func name. (#1838) Co-authored-by: Sungsoon Cho <sungsoon.cho@gmail.com>
This commit is contained in:
@@ -13,9 +13,7 @@ arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
|
||||
print("Running shark on cpu backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, device="cpu", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
|
||||
# Generate the random inputs and feed into the graph.
|
||||
x = shark_module.generate_random_inputs()
|
||||
@@ -23,15 +21,11 @@ shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on cuda backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, device="cuda", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="cuda", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
print("Running shark on vulkan backend")
|
||||
shark_module = SharkInference(
|
||||
mhlo_ir, device="vulkan", mlir_dialect="mhlo"
|
||||
)
|
||||
shark_module = SharkInference(mhlo_ir, device="vulkan", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
print(shark_module.forward(x))
|
||||
|
||||
@@ -8,9 +8,7 @@ mlir_model, func_name, inputs, golden_out = download_model(
|
||||
)
|
||||
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device="cpu", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(mlir_model, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward(inputs)
|
||||
print("The obtained result via shark is: ", result)
|
||||
|
||||
@@ -49,9 +49,7 @@ module = torch_mlir.compile(
|
||||
mlir_model = module
|
||||
func_name = "forward"
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_model, device="cuda", mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(mlir_model, device="cuda", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
|
||||
|
||||
|
||||
@@ -333,8 +333,12 @@ def compile_module_to_flatbuffer(
|
||||
return flatbuffer_blob
|
||||
|
||||
|
||||
def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
def get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=None, rt_flags: list = []
|
||||
):
|
||||
# Returns the compiled module and the configs.
|
||||
for flag in rt_flags:
|
||||
ireert.flags.parse_flag(flag)
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
@@ -356,9 +360,22 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
|
||||
|
||||
def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
flatbuffer_blob_or_path,
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
print(f"Loading module {flatbuffer_blob_or_path}...")
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
rt_flags.append(flag)
|
||||
for flag in rt_flags:
|
||||
print(flag)
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
if "rocm" in device:
|
||||
device = "rocm"
|
||||
with DetailLogger(timeout=2.5) as dl:
|
||||
@@ -385,6 +402,7 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
@@ -404,6 +422,8 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
for flag in shark_args.additional_runtime_args:
|
||||
ireert.flags.parse_flags(flag)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
if "vulkan" in device:
|
||||
# Vulkan pipeline creation consumes significant amount of time.
|
||||
@@ -430,6 +450,7 @@ def get_iree_compiled_module(
|
||||
frontend: str = "torch",
|
||||
model_config_path: str = None,
|
||||
extra_args: list = [],
|
||||
rt_flags: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
@@ -452,11 +473,14 @@ def get_iree_compiled_module(
|
||||
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
|
||||
if mmap:
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_blob, device, device_idx
|
||||
flatbuffer_blob, device, device_idx, rt_flags
|
||||
)
|
||||
else:
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
@@ -471,17 +495,21 @@ def load_flatbuffer(
|
||||
device: str,
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
temp_file_to_unlink = None
|
||||
if mmap:
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_path, device, device_idx
|
||||
flatbuffer_path, device, device_idx, rt_flags
|
||||
)
|
||||
else:
|
||||
with open(os.path.join(flatbuffer_path), "rb") as f:
|
||||
flatbuffer_blob = f.read()
|
||||
vmfb, config = get_iree_module(
|
||||
flatbuffer_blob, device, device_idx=device_idx
|
||||
flatbuffer_blob,
|
||||
device,
|
||||
device_idx=device_idx,
|
||||
rt_flags=rt_flags,
|
||||
)
|
||||
ret_params = {
|
||||
"vmfb": vmfb,
|
||||
|
||||
@@ -26,7 +26,7 @@ class SplitStrToListAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
del parser, option_string
|
||||
setattr(namespace, self.dest, shlex.split(values[0]))
|
||||
setattr(namespace, self.dest, shlex.split(" "))
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="SHARK runner.")
|
||||
@@ -44,6 +44,13 @@ parser.add_argument(
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--additional_runtime_args",
|
||||
default=list(),
|
||||
nargs=1,
|
||||
action=SplitStrToListAction,
|
||||
help="Additional arguments to pass to the IREE runtime. These are appended as the last arguments.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_tf32",
|
||||
type=bool,
|
||||
|
||||
@@ -73,6 +73,7 @@ class SharkInference:
|
||||
dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
|
||||
device_idx: int = None,
|
||||
mmap: bool = True,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
if mlir_module is not None:
|
||||
@@ -100,6 +101,7 @@ class SharkInference:
|
||||
|
||||
self.shark_runner = None
|
||||
self.mmap = mmap
|
||||
self.rt_flags = rt_flags
|
||||
|
||||
def compile(self, extra_args=[]):
|
||||
if self.dispatch_benchmarks is not None:
|
||||
@@ -134,6 +136,7 @@ class SharkInference:
|
||||
self.mlir_dialect,
|
||||
extra_args=extra_args,
|
||||
device_idx=self.device_idx,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
|
||||
if self.dispatch_benchmarks is not None:
|
||||
@@ -220,12 +223,14 @@ class SharkInference:
|
||||
device=self.device,
|
||||
compile_vmfb=False,
|
||||
extra_args=extra_args,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
params = load_flatbuffer(
|
||||
path,
|
||||
self.device,
|
||||
self.device_idx,
|
||||
mmap=self.mmap,
|
||||
rt_flags=self.rt_flags,
|
||||
)
|
||||
self.shark_runner.iree_compilation_module = params["vmfb"]
|
||||
self.shark_runner.iree_config = params["config"]
|
||||
|
||||
@@ -72,6 +72,7 @@ class SharkRunner:
|
||||
extra_args: list = [],
|
||||
compile_vmfb: bool = True,
|
||||
device_idx: int = None,
|
||||
rt_flags: list = [],
|
||||
):
|
||||
self.mlir_module = mlir_module
|
||||
if self.mlir_module is not None:
|
||||
@@ -86,6 +87,7 @@ class SharkRunner:
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
self.device_idx = device_idx
|
||||
self.rt_flags = rt_flags
|
||||
|
||||
if check_device_drivers(self.device):
|
||||
print(device_driver_info(self.device))
|
||||
@@ -99,6 +101,7 @@ class SharkRunner:
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
device_idx=self.device_idx,
|
||||
rt_flags=self.rt_flags,
|
||||
compile_str=self.compile_str,
|
||||
)
|
||||
self.iree_compilation_module = params["vmfb"]
|
||||
|
||||
@@ -19,12 +19,15 @@ import json
|
||||
import os
|
||||
import psutil
|
||||
import time
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.parser import shark_args
|
||||
import iree.compiler as ireec
|
||||
|
||||
DEVICE = "cpu"
|
||||
PLATFORM_SHARK = "shark"
|
||||
@@ -63,12 +66,11 @@ def get_memory_info():
|
||||
return process.memory_info()
|
||||
|
||||
|
||||
def create_vmfb_module(
|
||||
def import_mlir_module(
|
||||
model_name: str,
|
||||
tokenizer,
|
||||
device: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
|
||||
opt_base_model.eval()
|
||||
@@ -87,6 +89,27 @@ def create_vmfb_module(
|
||||
# np.save("model_inputs_0.npy", inputs[0])
|
||||
# np.save("model_inputs_1.npy", inputs[1])
|
||||
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
|
||||
(model_mlir, func_name) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=opt_fs_name,
|
||||
return_str=True,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
|
||||
def create_vmfb_module(
|
||||
model_name: str,
|
||||
tokenizer,
|
||||
device: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
):
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
|
||||
# If MLIR has already been loaded and recompilation is not requested, use
|
||||
@@ -96,39 +119,32 @@ def create_vmfb_module(
|
||||
# compilation time can be correctly measured only when MLIR has already been
|
||||
# loaded.
|
||||
assert not recompile_shark or has_mlir
|
||||
if has_mlir:
|
||||
with open(mlir_path, "r") as f:
|
||||
model_mlir = f.read()
|
||||
print(f"Loaded .mlir from {mlir_path}")
|
||||
else:
|
||||
(model_mlir, func_name) = import_with_fx(
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=opt_fs_name,
|
||||
return_str=True,
|
||||
if not has_mlir:
|
||||
import_mlir_module(
|
||||
model_name,
|
||||
tokenizer,
|
||||
device,
|
||||
max_seq_len,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
f.write(model_mlir)
|
||||
print(f"Saved mlir at {mlir_path}")
|
||||
|
||||
shark_module = SharkInference(
|
||||
model_mlir,
|
||||
mlir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
is_benchmark=False,
|
||||
rt_flags=[],
|
||||
)
|
||||
|
||||
vmfb_name = (
|
||||
f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels"
|
||||
)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}"
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
|
||||
|
||||
def load_shark_model(
|
||||
model_name: str, max_seq_len: int, recompile_shark: bool
|
||||
model_name: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
plugin_path: str = [],
|
||||
) -> ModelWrapper:
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
|
||||
@@ -138,7 +154,13 @@ def load_shark_model(
|
||||
create_vmfb_module(
|
||||
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
|
||||
)
|
||||
shark_module = SharkInference(mlir_module=None, device="cpu-task")
|
||||
if plugin_path is not None:
|
||||
rt_flags = [f"--executable_plugin={plugin_path}"]
|
||||
else:
|
||||
rt_flags = []
|
||||
shark_module = SharkInference(
|
||||
mlir_module=None, device="cpu-task", rt_flags=rt_flags
|
||||
)
|
||||
shark_module.load_module(vmfb_name)
|
||||
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
|
||||
|
||||
@@ -218,10 +240,13 @@ def collect_shark_logits(
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
to_save_json: bool,
|
||||
plugin_path: str,
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
model_wrapper = load_shark_model(model_name, max_seq_len, recompile_shark)
|
||||
model_wrapper = load_shark_model(
|
||||
model_name, max_seq_len, recompile_shark, plugin_path
|
||||
)
|
||||
load_time = time.time() - t0
|
||||
print("--- Took {} seconds to load Shark.".format(load_time))
|
||||
load_memory_info = get_memory_info()
|
||||
@@ -318,6 +343,12 @@ def parse_args():
|
||||
choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
|
||||
default=PLATFORM_SHARK,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin_path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
@@ -331,6 +362,7 @@ if __name__ == "__main__":
|
||||
args.max_seq_len,
|
||||
args.recompile_shark,
|
||||
args.save_json,
|
||||
args.plugin_path,
|
||||
)
|
||||
print("# Summary: {}".format(json.dumps(shark_report)))
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user