From caf6cc5d8f9432938ae8ecbf2e2dcbdd537efa22 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Fri, 6 Oct 2023 23:04:43 -0500 Subject: [PATCH] Switch most compile flows to use ireec.compile_file. (#1863) * Switch most compile flows to use ireec.compile_file. * re-add input type to compile_str path. * Check if mlir_module exists before checking if it's a path or pyobject. * Fix some save_dir cases --- .../langchain/h2oai_pipeline.py | 7 +++- apps/language_models/scripts/vicuna.py | 41 ++++++++----------- .../src/model_wrappers/vicuna4.py | 1 - .../src/pipelines/falcon_pipeline.py | 11 ++--- .../src/pipelines/minigpt4_pipeline.py | 8 +++- .../src/models/model_wrappers.py | 5 ++- apps/stable_diffusion/src/utils/utils.py | 13 ++++-- shark/iree_utils/compile_utils.py | 31 ++++++++++---- shark/shark_benchmark_runner.py | 8 ++++ shark/shark_compile.py | 22 +++++++--- shark/shark_downloader.py | 12 +++--- shark/shark_eager/shark_eager.py | 10 ++++- shark/shark_generate_model_config.py | 14 +++++-- shark/shark_importer.py | 22 ++++++++++ shark/shark_inference.py | 13 +++++- shark/shark_runner.py | 11 ++++- shark/shark_trainer.py | 8 +++- tank/examples/opt/opt_causallm.py | 7 ++-- tank/examples/opt/opt_causallm_torch_test.py | 5 ++- tank/examples/opt/shark_hf_base_opt.py | 12 +++--- tank/generate_sharktank.py | 2 +- 21 files changed, 184 insertions(+), 79 deletions(-) diff --git a/apps/language_models/langchain/h2oai_pipeline.py b/apps/language_models/langchain/h2oai_pipeline.py index 86f5c780..57c9f2b0 100644 --- a/apps/language_models/langchain/h2oai_pipeline.py +++ b/apps/language_models/langchain/h2oai_pipeline.py @@ -20,7 +20,7 @@ import gc from pathlib import Path from shark.shark_inference import SharkInference from shark.shark_downloader import download_public_file -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir from apps.stable_diffusion.src import args # Brevitas @@ -256,6 +256,11 @@ class H2OGPTSHARKModel(torch.nn.Module): bytecode = bytecode_stream.getvalue() del module + bytecode = save_mlir( + bytecode, + model_name=f"h2ogpt_{precision}", + frontend="torch", + ) return bytecode def forward(self, input_ids, attention_mask): diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index e5803ab2..d79b0550 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -49,7 +49,7 @@ from apps.language_models.utils import ( ) from shark.shark_downloader import download_public_file from shark.shark_importer import get_f16_inputs -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir from shark.shark_inference import SharkInference @@ -672,9 +672,7 @@ class ShardedVicuna(VicunaBase): mlir_path = Path(f"lmhead.mlir") vmfb_path = Path(f"lmhead.vmfb") if mlir_path.exists(): - f_ = open(mlir_path, "rb") - bytecode = f_.read() - f_.close() + print(f"Found bytecode module at {mlir_path}.") else: hidden_states = torch_mlir.TensorPlaceholder.like( hidden_states, dynamic_axes=[1] @@ -699,12 +697,10 @@ class ShardedVicuna(VicunaBase): filepath.absolute(), single_file=True, ) - f_ = open(f"lmhead.mlir", "rb") - bytecode = f_.read() - f_.close() + mlir_path = filepath shark_module = SharkInference( - bytecode, + mlir_path, device=device, mlir_dialect="tm_tensor", device_idx=device_idx, @@ -724,9 +720,7 @@ class ShardedVicuna(VicunaBase): mlir_path = Path(f"norm.mlir") 
vmfb_path = Path(f"norm.vmfb") if mlir_path.exists(): - f_ = open(mlir_path, "rb") - bytecode = f_.read() - f_.close() + print(f"Found bytecode module at {mlir_path}.") else: hidden_states = torch_mlir.TensorPlaceholder.like( hidden_states, dynamic_axes=[1] @@ -745,12 +739,10 @@ class ShardedVicuna(VicunaBase): filepath.absolute(), single_file=True, ) - f_ = open(f"norm.mlir", "rb") - bytecode = f_.read() - f_.close() + mlir_path = filepath shark_module = SharkInference( - bytecode, + mlir_path, device=device, mlir_dialect="tm_tensor", device_idx=device_idx, @@ -770,9 +762,7 @@ class ShardedVicuna(VicunaBase): mlir_path = Path(f"embedding.mlir") vmfb_path = Path(f"embedding.vmfb") if mlir_path.exists(): - f_ = open(mlir_path, "rb") - bytecode = f_.read() - f_.close() + print(f"Found bytecode module at {mlir_path}.") else: input_ids = torch_mlir.TensorPlaceholder.like( input_ids, dynamic_axes=[1] @@ -796,12 +786,10 @@ class ShardedVicuna(VicunaBase): filepath.absolute(), single_file=True, ) - f_ = open(f"embedding.mlir", "rb") - bytecode = f_.read() - f_.close() + mlir_path = filepath shark_module = SharkInference( - bytecode, + mlir_path, device=device, mlir_dialect="tm_tensor", device_idx=device_idx, @@ -1474,8 +1462,7 @@ class UnshardedVicuna(VicunaBase): ) if self.vicuna_mlir_path.exists(): print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}") - with open(self.vicuna_mlir_path, "rb") as f: - combined_module = f.read() + combined_module = self.vicuna_mlir_path.absolute() mlir_generated = True break @@ -1697,6 +1684,12 @@ class UnshardedVicuna(VicunaBase): second_module, self.vicuna_mlir_path, ) + combined_module = save_mlir( + combined_module, + model_name="combined_llama", + mlir_dialect="tm_tensor" + dir=self.vicuna_mlir_path, + ) del first_module, second_module print(self.device) diff --git a/apps/language_models/src/model_wrappers/vicuna4.py b/apps/language_models/src/model_wrappers/vicuna4.py index aa033a9f..10bef66f 100644 --- a/apps/language_models/src/model_wrappers/vicuna4.py +++ b/apps/language_models/src/model_wrappers/vicuna4.py @@ -54,7 +54,6 @@ from apps.language_models.utils import ( ) from shark.shark_downloader import download_public_file from shark.shark_importer import get_f16_inputs -from shark.shark_importer import import_with_fx from shark.shark_inference import SharkInference from transformers.models.llama.configuration_llama import LlamaConfig diff --git a/apps/language_models/src/pipelines/falcon_pipeline.py b/apps/language_models/src/pipelines/falcon_pipeline.py index 94c754d0..1767b629 100644 --- a/apps/language_models/src/pipelines/falcon_pipeline.py +++ b/apps/language_models/src/pipelines/falcon_pipeline.py @@ -7,7 +7,7 @@ from io import BytesIO from pathlib import Path from contextlib import redirect_stdout from shark.shark_downloader import download_public_file -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir from shark.shark_inference import SharkInference from transformers import AutoTokenizer, AutoModelForCausalLM from transformers.generation import ( @@ -174,8 +174,6 @@ class Falcon(SharkLLMBase): print( f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}" ) - with open(self.falcon_mlir_path, "rb") as f: - bytecode = f.read() mlir_generated = True if not mlir_generated: @@ -223,12 +221,15 @@ class Falcon(SharkLLMBase): f_.write(bytecode) print("Saved falcon mlir at ", str(self.falcon_mlir_path)) f_.close() + del bytecode shark_module = SharkInference( - 
mlir_module=bytecode, device=self.device, mlir_dialect="linalg" + mlir_module=self.falcon_mlir_path, + device=self.device, + mlir_dialect="linalg", ) path = shark_module.save_module( - self.falcon_vmfb_path.parent.absolute(), + self.falcon_vmfb_path, self.falcon_vmfb_path.stem, extra_args=[ "--iree-vm-target-truncate-unsupported-floats", diff --git a/apps/language_models/src/pipelines/minigpt4_pipeline.py b/apps/language_models/src/pipelines/minigpt4_pipeline.py index cf7f06a5..5d5726a8 100644 --- a/apps/language_models/src/pipelines/minigpt4_pipeline.py +++ b/apps/language_models/src/pipelines/minigpt4_pipeline.py @@ -126,7 +126,7 @@ def is_url(input_url): import os import tempfile from shark.shark_inference import SharkInference -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir import torch import torch_mlir from torch_mlir.compiler_utils import run_pipeline_with_repro_report @@ -235,6 +235,12 @@ def compile_int_precision( mlir_module = BytesIO(mlir_module) bytecode = mlir_module.read() print(f"Elided IR written for {extended_model_name}") + bytecode = save_mlir( + bytecode, + model_name=extended_model_name, + frontend="torch", + dir=os.getcwd(), + ) return bytecode shark_module = SharkInference( mlir_module=bytecode, device=device, mlir_dialect="tm_tensor" diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index 16835917..2ef9f696 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -710,8 +710,11 @@ class SharkifyStableDiffusionModel: return self.text_encoder(input)[0] clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage) - save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"]) + save_dir = "" if self.debug: + save_dir = os.path.join( + self.sharktank_dir, self.model_name["clip"] + ) os.makedirs( save_dir, exist_ok=True, diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 41b9fb07..d97c6b85 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -18,7 +18,7 @@ import tempfile import torch from safetensors.torch import load_file from shark.shark_inference import SharkInference -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir from shark.iree_utils.vulkan_utils import ( set_iree_vulkan_runtime_flags, get_vulkan_target_triple, @@ -154,8 +154,8 @@ def compile_through_fx( f16_input_mask=f16_input_mask, debug=debug, model_name=extended_model_name, - save_dir=save_dir, ) + if use_tuned: if "vae" in extended_model_name.split("_")[0]: args.annotation_model = "vae" @@ -168,6 +168,14 @@ def compile_through_fx( mlir_module, extended_model_name, base_model_id ) + if not os.path.isdir(save_dir): + save_dir = "" + + mlir_module = save_mlir( + mlir_module, + model_name=extended_model_name, + dir=save_dir, + ) shark_module = SharkInference( mlir_module, device=args.device if device is None else device, @@ -179,7 +187,6 @@ def compile_through_fx( mlir_module, ) - del mlir_module gc.collect() diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 6dc545c8..8910d547 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -292,9 +292,10 @@ def compile_module_to_flatbuffer( extra_args, model_name="None", debug=False, + compile_str=False, ): # Setup Compile arguments wrt 
to frontends. - input_type = "" + input_type = "auto" args = get_iree_frontend_args(frontend) args += get_iree_device_args(device, extra_args) args += get_iree_common_args(debug=debug) @@ -311,10 +312,7 @@ def compile_module_to_flatbuffer( elif frontend in ["tm_tensor"]: input_type = ireec.InputType.TM_TENSOR - # TODO: make it simpler. - # Compile according to the input type, else just try compiling. - if input_type != "": - # Currently for MHLO/TOSA. + if compile_str: flatbuffer_blob = ireec.compile_str( module, target_backends=[iree_target_map(device)], @@ -322,9 +320,10 @@ def compile_module_to_flatbuffer( input_type=input_type, ) else: - # Currently for Torch. - flatbuffer_blob = ireec.compile_str( + assert os.path.isfile(module) + flatbuffer_blob = ireec.compile_file( module, + input_type=input_type, target_backends=[iree_target_map(device)], extra_args=args, ) @@ -432,10 +431,17 @@ def get_iree_compiled_module( device_idx: int = None, mmap: bool = False, debug: bool = False, + compile_str: bool = False, ): """Given a module returns the compiled .vmfb and configs""" flatbuffer_blob = compile_module_to_flatbuffer( - module, device, frontend, model_config_path, extra_args, debug + module, + device, + frontend, + model_config_path, + extra_args, + debug, + compile_str, ) temp_file_to_unlink = None # TODO: Currently mmap=True control flow path has been switched off for mmap. @@ -492,10 +498,17 @@ def export_iree_module_to_vmfb( module_name: str = None, extra_args: list = [], debug: bool = False, + compile_str: bool = False, ): # Compiles the module given specs and saves it as .vmfb file. flatbuffer_blob = compile_module_to_flatbuffer( - module, device, mlir_dialect, model_config_path, extra_args, debug + module, + device, + mlir_dialect, + model_config_path, + extra_args, + debug, + compile_str, ) if module_name is None: device_name = ( diff --git a/shark/shark_benchmark_runner.py b/shark/shark_benchmark_runner.py index 88fa9be6..4d87d788 100644 --- a/shark/shark_benchmark_runner.py +++ b/shark/shark_benchmark_runner.py @@ -84,6 +84,13 @@ class SharkBenchmarkRunner(SharkRunner): self.extra_args = extra_args self.import_args = {} self.temp_file_to_unlink = None + if not os.path.isfile(mlir_module): + print( + "Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead." 
+ ) + self.compile_str = True + else: + self.compile_str = False SharkRunner.__init__( self, mlir_module, @@ -98,6 +105,7 @@ class SharkBenchmarkRunner(SharkRunner): ".", self.mlir_dialect, extra_args=self.extra_args, + compile_str=self.compile_str, ) params = load_flatbuffer( self.vmfb_file, diff --git a/shark/shark_compile.py b/shark/shark_compile.py index ba21805c..661af47f 100644 --- a/shark/shark_compile.py +++ b/shark/shark_compile.py @@ -1,7 +1,7 @@ import os import tempfile from shark.shark_inference import SharkInference -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir import torch import torch_mlir from torch_mlir.compiler_utils import run_pipeline_with_repro_report @@ -130,10 +130,17 @@ def compile_int_precision( mlir_module = mlir_module.encode("UTF-8") mlir_module = BytesIO(mlir_module) bytecode = mlir_module.read() + bytecode_path = os.path.join( + os.getcwd(), f"{extended_model_name}_linalg.mlirbc" + ) + with open(bytecode_path, "wb") as f: + f.write(bytecode) + del bytecode + del mlir_module print(f"Elided IR written for {extended_model_name}") - return bytecode + return bytecode_path shark_module = SharkInference( - mlir_module=bytecode, device=device, mlir_dialect="tm_tensor" + mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor" ) extra_args = [ "--iree-hal-dump-executable-sources-to=ies", @@ -148,7 +155,7 @@ def compile_int_precision( generate_vmfb=generate_vmfb, extra_args=extra_args, ), - bytecode, + bytecode_path, ) @@ -201,7 +208,7 @@ def shark_compile_through_fx( ] else: ( - mlir_module, + bytecode, _, ) = import_with_fx( model=model, @@ -212,6 +219,11 @@ def shark_compile_through_fx( model_name=extended_model_name, save_dir=save_dir, ) + mlir_module = save_mlir( + mlir_module=bytecode, + model_name=extended_model_name, + mlir_dialect=mlir_dialect, + ) shark_module = SharkInference( mlir_module, diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py index 633d752e..b9baf325 100644 --- a/shark/shark_downloader.py +++ b/shark/shark_downloader.py @@ -275,11 +275,11 @@ def download_model( model_dir = os.path.join(WORKDIR, model_dir_name) tuned_str = "" if tuned is None else "_" + tuned suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir" - filename = os.path.join(model_dir, model_name + suffix) + mlir_filename = os.path.join(model_dir, model_name + suffix) print( - f"Verifying that model artifacts were downloaded successfully to {filename}..." + f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..." 
) - if not os.path.exists(filename): + if not os.path.exists(mlir_filename): from tank.generate_sharktank import gen_shark_files print( @@ -287,13 +287,11 @@ def download_model( ) gen_shark_files(model_name, frontend, WORKDIR, import_args) - assert os.path.exists(filename), f"MLIR not found at {filename}" - with open(filename, mode="rb") as f: - mlir_file = f.read() + assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}" function_name = str(np.load(os.path.join(model_dir, "function_name.npy"))) inputs = np.load(os.path.join(model_dir, "inputs.npz")) golden_out = np.load(os.path.join(model_dir, "golden_out.npz")) inputs_tuple = tuple([inputs[key] for key in inputs]) golden_out_tuple = tuple([golden_out[key] for key in golden_out]) - return mlir_file, function_name, inputs_tuple, golden_out_tuple + return mlir_filename, function_name, inputs_tuple, golden_out_tuple diff --git a/shark/shark_eager/shark_eager.py b/shark/shark_eager/shark_eager.py index 807c6c01..bd751199 100644 --- a/shark/shark_eager/shark_eager.py +++ b/shark/shark_eager/shark_eager.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Tuple from collections import defaultdict -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir import torchvision.models as models import copy import io @@ -20,10 +20,16 @@ def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"): bytecode_stream = io.BytesIO() mlir_module.operation.write_bytecode(bytecode_stream) bytecode = bytecode_stream.getvalue() + bytecode_path = save_mlir( + bytecode, + model_name="shark_eager_module", + frontend="torch", + mlir_dialect="tm_tensor", + ) from shark.shark_inference import SharkInference shark_module = SharkInference( - mlir_module=bytecode, + mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor", ) diff --git a/shark/shark_generate_model_config.py b/shark/shark_generate_model_config.py index e177dccc..9847b116 100644 --- a/shark/shark_generate_model_config.py +++ b/shark/shark_generate_model_config.py @@ -3,8 +3,8 @@ import json import numpy as np import torch_mlir -from iree.compiler import compile_str -from shark.shark_importer import import_with_fx, get_f16_inputs +from iree.compiler import compile_file +from shark.shark_importer import import_with_fx, get_f16_inputs, save_mlir class GenerateConfigFile: @@ -54,9 +54,15 @@ class GenerateConfigFile: verbose=False, ) module = module.operation.get_asm(large_elements_limit=4) + module_file = save_mlir( + module, + model_name="module_pre_split", + frontend="torch", + mlir_dialect="linalg", + ) compiled_module_str = str( - compile_str( - str(module), + compile_file( + module_file, target_backends=[backend], extra_args=[ "--compile-to=flow", diff --git a/shark/shark_importer.py b/shark/shark_importer.py index 498e218b..2722f6cb 100644 --- a/shark/shark_importer.py +++ b/shark/shark_importer.py @@ -749,3 +749,25 @@ def import_with_fx( mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type) return mlir_module, func_name + + +# Saves a .mlir module python object to the directory 'dir' with 'model_name' and returns a path to the saved file. 
+def save_mlir( + mlir_module, + model_name, + mlir_dialect="linalg", + frontend="torch", + dir=tempfile.gettempdir(), +): + model_name_mlir = ( + model_name + "_" + frontend + "_" + mlir_dialect + ".mlir" + ) + if dir == "": + dir = tempfile.gettempdir() + mlir_path = os.path.join(dir, model_name_mlir) + print(f"saving {model_name_mlir} to {dir}") + if frontend == "torch": + with open(mlir_path, "wb") as mlir_file: + mlir_file.write(mlir_module) + + return mlir_path diff --git a/shark/shark_inference.py b/shark/shark_inference.py index 6f942b73..f2e14ee7 100644 --- a/shark/shark_inference.py +++ b/shark/shark_inference.py @@ -39,7 +39,7 @@ class SharkInference: Attributes ---------- mlir_module : str - mlir_module represented in string; modules from torch-mlir are serialized in bytecode format. + mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format. device : str device to execute the mlir_module on. currently supports cpu, cuda, vulkan, and metal backends. @@ -65,7 +65,7 @@ class SharkInference: def __init__( self, - mlir_module: bytes, + mlir_module, device: str = "none", mlir_dialect: str = "linalg", is_benchmark: bool = False, @@ -75,6 +75,14 @@ class SharkInference: mmap: bool = True, ): self.mlir_module = mlir_module + if mlir_module is not None: + if mlir_module and not os.path.isfile(mlir_module): + print( + "Warning: Initializing SharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead." + ) + self.compile_str = True + else: + self.compile_str = False self.device = shark_args.device if device == "none" else device self.mlir_dialect = mlir_dialect self.is_benchmark = is_benchmark @@ -203,6 +211,7 @@ class SharkInference: module_name=module_name, extra_args=extra_args, debug=debug, + compile_str=self.compile_str, ) # load and return the module. diff --git a/shark/shark_runner.py b/shark/shark_runner.py index 2552dd6a..2eec2b97 100644 --- a/shark/shark_runner.py +++ b/shark/shark_runner.py @@ -45,7 +45,7 @@ class SharkRunner: Attributes ---------- mlir_module : str - mlir_module represented in string. + mlir_module path, string, or bytecode. device : str device to execute the mlir_module on. currently supports cpu, cuda, vulkan, and metal backends. @@ -74,6 +74,14 @@ class SharkRunner: device_idx: int = None, ): self.mlir_module = mlir_module + if self.mlir_module is not None: + if not os.path.isfile(mlir_module): + print( + "Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead." 
+ ) + self.compile_str = True + else: + self.compile_str = False self.device = shark_args.device if device == "none" else device self.mlir_dialect = mlir_dialect self.extra_args = extra_args @@ -91,6 +99,7 @@ class SharkRunner: self.mlir_dialect, extra_args=self.extra_args, device_idx=self.device_idx, + compile_str=self.compile_str, ) self.iree_compilation_module = params["vmfb"] self.iree_config = params["config"] diff --git a/shark/shark_trainer.py b/shark/shark_trainer.py index 36916f24..16bdd984 100644 --- a/shark/shark_trainer.py +++ b/shark/shark_trainer.py @@ -15,7 +15,7 @@ from shark.parser import shark_args from shark.shark_runner import SharkRunner from shark.backward_makefx import MakeFxModule -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir import numpy as np from tqdm import tqdm import sys @@ -84,6 +84,12 @@ class SharkTrainer: training=True, mlir_type=mlir_type, ) + mlir_module = save_mlir( + mlir_module, + model_name="shark_model", + frontend="torch", + mlir_dialect=mlir_type, + ) self.shark_runner = SharkRunner( mlir_module, self.device, diff --git a/tank/examples/opt/opt_causallm.py b/tank/examples/opt/opt_causallm.py index 8db0fe54..210671c5 100644 --- a/tank/examples/opt/opt_causallm.py +++ b/tank/examples/opt/opt_causallm.py @@ -36,9 +36,7 @@ def create_module(model_name, tokenizer, device): mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir" if os.path.isfile(mlir_path): - with open(mlir_path, "r") as f: - model_mlir = f.read() - print(f"Loaded .mlir from {mlir_path}") + print(f"Found .mlir from {mlir_path}") else: (model_mlir, func_name) = import_with_fx( model=opt_model, @@ -50,9 +48,10 @@ def create_module(model_name, tokenizer, device): with open(mlir_path, "w") as f: f.write(model_mlir) print(f"Saved mlir at {mlir_path}") + del model_mlir shark_module = SharkInference( - model_mlir, + mlir_path, device=device, mlir_dialect="tm_tensor", is_benchmark=False, diff --git a/tank/examples/opt/opt_causallm_torch_test.py b/tank/examples/opt/opt_causallm_torch_test.py index c57cebe8..17cc1a93 100644 --- a/tank/examples/opt/opt_causallm_torch_test.py +++ b/tank/examples/opt/opt_causallm_torch_test.py @@ -6,7 +6,7 @@ import numpy as np from shark_opt_wrapper import OPTForCausalLMModel from shark.iree_utils._common import check_device_drivers, device_driver_info from shark.shark_inference import SharkInference -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir from transformers import AutoTokenizer, OPTForCausalLM OPT_MODEL = "facebook/opt-1.3b" @@ -57,9 +57,10 @@ class OPTModuleTester: with open(mlir_path, "w") as f: f.write(mlir_module) print(f"Saved mlir at {mlir_path}") + del mlir_module shark_module = SharkInference( - mlir_module, + mlir_path, device=device, mlir_dialect="tm_tensor", is_benchmark=self.benchmark, diff --git a/tank/examples/opt/shark_hf_base_opt.py b/tank/examples/opt/shark_hf_base_opt.py index 8acf374a..ef41b8d3 100644 --- a/tank/examples/opt/shark_hf_base_opt.py +++ b/tank/examples/opt/shark_hf_base_opt.py @@ -2,7 +2,7 @@ import os import torch from transformers import AutoTokenizer, OPTForCausalLM from shark.shark_inference import SharkInference -from shark.shark_importer import import_with_fx +from shark.shark_importer import import_with_fx, save_mlir from shark_opt_wrapper import OPTForCausalLMModel model_name = "facebook/opt-1.3b" @@ -25,11 +25,13 @@ inputs = ( model=model, inputs=inputs, is_f16=False, - 
debug=True, - model_name=model_name.split("/")[1], - save_dir=".", ) - +mlir_module = save_mlir( + mlir_module, + model_name=model_name.split("/")[1], + frontend="torch", + mlir_dialect="linalg", +) shark_module = SharkInference( mlir_module, device="cpu-sync", diff --git a/tank/generate_sharktank.py b/tank/generate_sharktank.py index d4f08ceb..67f18281 100644 --- a/tank/generate_sharktank.py +++ b/tank/generate_sharktank.py @@ -36,7 +36,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args): get_hf_img_cls_model, get_fp16_model, ) - from shark.shark_importer import import_with_fx + from shark.shark_importer import import_with_fx, save_mlir with open(torch_model_list) as csvfile: torch_reader = csv.reader(csvfile, delimiter=",")
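For reference, a minimal usage sketch of the flow most call sites in this patch converge on: import the model, write the MLIR module to disk with save_mlir, and hand SharkInference a file path so compile_module_to_flatbuffer takes the ireec.compile_file branch rather than compile_str and the module is not duplicated in RAM at compile time. The call pattern mirrors shark_compile_through_fx; the toy model, inputs, names, and device string below are placeholders and not part of the patch.

import torch
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference

# Placeholder model and inputs; any module import_with_fx can trace works here.
model = torch.nn.Linear(4, 4)
inputs = (torch.randn(1, 4),)

# Import to torch-MLIR, then persist the module and keep only its path.
mlir_module, func_name = import_with_fx(
    model=model,
    inputs=inputs,
    is_f16=False,
    model_name="toy_linear",
)
mlir_path = save_mlir(
    mlir_module,
    model_name="toy_linear",
    frontend="torch",
    mlir_dialect="tm_tensor",
)

# A path (not a bytecode object) keeps compile_str False, so IREE compiles the
# file on disk via ireec.compile_file.
shark_module = SharkInference(
    mlir_path,
    device="cpu",
    mlir_dialect="tm_tensor",
)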