Switch most compile flows to use ireec.compile_file. (#1863)

* Switch most compile flows to use ireec.compile_file.

* Re-add input type to the compile_str path.

* Check if mlir_module exists before checking whether it's a path or a Python object.

* Fix some save_dir cases.
Ean Garvey
2023-10-06 23:04:43 -05:00
committed by GitHub
parent 8614a18474
commit caf6cc5d8f
21 changed files with 184 additions and 79 deletions
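
Taken together, these changes move SHARK away from compiling in-memory MLIR strings and toward saving imported modules to disk with save_mlir and compiling the resulting file with ireec.compile_file. A minimal sketch of that flow, assuming a SHARK checkout; the toy model, device, and exact import_with_fx keyword arguments are illustrative and may differ from the real signatures:

import torch
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference

model = torch.nn.Linear(4, 4)
inputs = (torch.randn(1, 4),)

# Import the model to torch-mlir bytecode (keyword arguments are illustrative).
mlir_module, func_name = import_with_fx(
    model=model,
    inputs=inputs,
    model_name="toy_linear",
)

# Persist the bytecode; save_mlir returns a path like
# <dir>/toy_linear_torch_linalg.mlir instead of keeping the blob in RAM.
mlir_path = save_mlir(
    mlir_module,
    model_name="toy_linear",
    frontend="torch",
    mlir_dialect="linalg",
)

# Hand SharkInference a file path so compilation can go through ireec.compile_file.
shark_module = SharkInference(
    mlir_path,
    device="cpu",
    mlir_dialect="linalg",
)
shark_module.compile()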

View File

@@ -20,7 +20,7 @@ import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from apps.stable_diffusion.src import args
# Brevitas
@@ -256,6 +256,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
bytecode = bytecode_stream.getvalue()
del module
bytecode = save_mlir(
bytecode,
model_name=f"h2ogpt_{precision}",
frontend="torch",
)
return bytecode
def forward(self, input_ids, attention_mask):

View File

@@ -49,7 +49,7 @@ from apps.language_models.utils import (
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
@@ -672,9 +672,7 @@ class ShardedVicuna(VicunaBase):
mlir_path = Path(f"lmhead.mlir")
vmfb_path = Path(f"lmhead.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
print(f"Found bytecode module at {mlir_path}.")
else:
hidden_states = torch_mlir.TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
@@ -699,12 +697,10 @@ class ShardedVicuna(VicunaBase):
filepath.absolute(),
single_file=True,
)
f_ = open(f"lmhead.mlir", "rb")
bytecode = f_.read()
f_.close()
mlir_path = filepath
shark_module = SharkInference(
bytecode,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
@@ -724,9 +720,7 @@ class ShardedVicuna(VicunaBase):
mlir_path = Path(f"norm.mlir")
vmfb_path = Path(f"norm.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
print(f"Found bytecode module at {mlir_path}.")
else:
hidden_states = torch_mlir.TensorPlaceholder.like(
hidden_states, dynamic_axes=[1]
@@ -745,12 +739,10 @@ class ShardedVicuna(VicunaBase):
filepath.absolute(),
single_file=True,
)
f_ = open(f"norm.mlir", "rb")
bytecode = f_.read()
f_.close()
mlir_path = filepath
shark_module = SharkInference(
bytecode,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
@@ -770,9 +762,7 @@ class ShardedVicuna(VicunaBase):
mlir_path = Path(f"embedding.mlir")
vmfb_path = Path(f"embedding.vmfb")
if mlir_path.exists():
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
print(f"Found bytecode module at {mlir_path}.")
else:
input_ids = torch_mlir.TensorPlaceholder.like(
input_ids, dynamic_axes=[1]
@@ -796,12 +786,10 @@ class ShardedVicuna(VicunaBase):
filepath.absolute(),
single_file=True,
)
f_ = open(f"embedding.mlir", "rb")
bytecode = f_.read()
f_.close()
mlir_path = filepath
shark_module = SharkInference(
bytecode,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
device_idx=device_idx,
@@ -1474,8 +1462,7 @@ class UnshardedVicuna(VicunaBase):
)
if self.vicuna_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
combined_module = self.vicuna_mlir_path.absolute()
mlir_generated = True
break
@@ -1697,6 +1684,12 @@ class UnshardedVicuna(VicunaBase):
second_module,
self.vicuna_mlir_path,
)
combined_module = save_mlir(
combined_module,
model_name="combined_llama",
mlir_dialect="tm_tensor",
dir=self.vicuna_mlir_path,
)
del first_module, second_module
print(self.device)

View File

@@ -54,7 +54,6 @@ from apps.language_models.utils import (
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers.models.llama.configuration_llama import LlamaConfig

View File

@@ -7,7 +7,7 @@ from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import (
@@ -174,8 +174,6 @@ class Falcon(SharkLLMBase):
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
if not mlir_generated:
@@ -223,12 +221,15 @@ class Falcon(SharkLLMBase):
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path,
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",

View File

@@ -126,7 +126,7 @@ def is_url(input_url):
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -235,6 +235,12 @@ def compile_int_precision(
mlir_module = BytesIO(mlir_module)
bytecode = mlir_module.read()
print(f"Elided IR written for {extended_model_name}")
bytecode = save_mlir(
bytecode,
model_name=extended_model_name,
frontend="torch",
dir=os.getcwd(),
)
return bytecode
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"

View File

@@ -710,8 +710,11 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
save_dir = ""
if self.debug:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["clip"]
)
os.makedirs(
save_dir,
exist_ok=True,

View File

@@ -18,7 +18,7 @@ import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
@@ -154,8 +154,8 @@ def compile_through_fx(
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
@@ -168,6 +168,14 @@ def compile_through_fx(
mlir_module, extended_model_name, base_model_id
)
if not os.path.isdir(save_dir):
save_dir = ""
mlir_module = save_mlir(
mlir_module,
model_name=extended_model_name,
dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=args.device if device is None else device,
@@ -179,7 +187,6 @@ def compile_through_fx(
mlir_module,
)
del mlir_module
gc.collect()

View File

@@ -292,9 +292,10 @@ def compile_module_to_flatbuffer(
extra_args,
model_name="None",
debug=False,
compile_str=False,
):
# Setup Compile arguments wrt to frontends.
input_type = ""
input_type = "auto"
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device, extra_args)
args += get_iree_common_args(debug=debug)
@@ -311,10 +312,7 @@ def compile_module_to_flatbuffer(
elif frontend in ["tm_tensor"]:
input_type = ireec.InputType.TM_TENSOR
# TODO: make it simpler.
# Compile according to the input type, else just try compiling.
if input_type != "":
# Currently for MHLO/TOSA.
if compile_str:
flatbuffer_blob = ireec.compile_str(
module,
target_backends=[iree_target_map(device)],
@@ -322,9 +320,10 @@ def compile_module_to_flatbuffer(
input_type=input_type,
)
else:
# Currently for Torch.
flatbuffer_blob = ireec.compile_str(
assert os.path.isfile(module)
flatbuffer_blob = ireec.compile_file(
module,
input_type=input_type,
target_backends=[iree_target_map(device)],
extra_args=args,
)
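
The hunk above collapses the old input-type branching into a single compile_str/compile_file switch. A standalone sketch of the same dispatch against the iree.compiler API; the helper name, default backend, and defaults are illustrative, not part of this commit:

import os
import iree.compiler as ireec

def compile_to_vmfb(module, compile_str=False, device_backend="llvm-cpu", extra_args=None):
    extra_args = extra_args or []
    if compile_str:
        # 'module' is MLIR text or bytecode held in memory.
        return ireec.compile_str(
            module,
            target_backends=[device_backend],
            extra_args=extra_args,
            input_type="auto",
        )
    # Otherwise 'module' must be a path to a .mlir / .mlirbc file on disk.
    assert os.path.isfile(module), f"expected a path to an MLIR file, got {module!r}"
    return ireec.compile_file(
        module,
        target_backends=[device_backend],
        extra_args=extra_args,
        input_type="auto",
    )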
@@ -432,10 +431,17 @@ def get_iree_compiled_module(
device_idx: int = None,
mmap: bool = False,
debug: bool = False,
compile_str: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args, debug
module,
device,
frontend,
model_config_path,
extra_args,
debug,
compile_str,
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -492,10 +498,17 @@ def export_iree_module_to_vmfb(
module_name: str = None,
extra_args: list = [],
debug: bool = False,
compile_str: bool = False,
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, mlir_dialect, model_config_path, extra_args, debug
module,
device,
mlir_dialect,
model_config_path,
extra_args,
debug,
compile_str,
)
if module_name is None:
device_name = (

View File

@@ -84,6 +84,13 @@ class SharkBenchmarkRunner(SharkRunner):
self.extra_args = extra_args
self.import_args = {}
self.temp_file_to_unlink = None
if not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkRunner with an MLIR string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to an MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
SharkRunner.__init__(
self,
mlir_module,
@@ -98,6 +105,7 @@ class SharkBenchmarkRunner(SharkRunner):
".",
self.mlir_dialect,
extra_args=self.extra_args,
compile_str=self.compile_str,
)
params = load_flatbuffer(
self.vmfb_file,

View File

@@ -1,7 +1,7 @@
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -130,10 +130,17 @@ def compile_int_precision(
mlir_module = mlir_module.encode("UTF-8")
mlir_module = BytesIO(mlir_module)
bytecode = mlir_module.read()
bytecode_path = os.path.join(
os.getcwd(), f"{extended_model_name}_linalg.mlirbc"
)
with open(bytecode_path, "wb") as f:
f.write(bytecode)
del bytecode
del mlir_module
print(f"Elided IR written for {extended_model_name}")
return bytecode
return bytecode_path
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor"
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",
@@ -148,7 +155,7 @@ def compile_int_precision(
generate_vmfb=generate_vmfb,
extra_args=extra_args,
),
bytecode,
bytecode_path,
)
@@ -201,7 +208,7 @@ def shark_compile_through_fx(
]
else:
(
mlir_module,
bytecode,
_,
) = import_with_fx(
model=model,
@@ -212,6 +219,11 @@ def shark_compile_through_fx(
model_name=extended_model_name,
save_dir=save_dir,
)
mlir_module = save_mlir(
mlir_module=bytecode,
model_name=extended_model_name,
mlir_dialect=mlir_dialect,
)
shark_module = SharkInference(
mlir_module,

View File

@@ -275,11 +275,11 @@ def download_model(
model_dir = os.path.join(WORKDIR, model_dir_name)
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
mlir_filename = os.path.join(model_dir, model_name + suffix)
print(
f"Verifying that model artifacts were downloaded successfully to {filename}..."
f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
)
if not os.path.exists(filename):
if not os.path.exists(mlir_filename):
from tank.generate_sharktank import gen_shark_files
print(
@@ -287,13 +287,11 @@ def download_model(
)
gen_shark_files(model_name, frontend, WORKDIR, import_args)
assert os.path.exists(filename), f"MLIR not found at {filename}"
with open(filename, mode="rb") as f:
mlir_file = f.read()
assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_file, function_name, inputs_tuple, golden_out_tuple
return mlir_filename, function_name, inputs_tuple, golden_out_tuple

View File

@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import torchvision.models as models
import copy
import io
@@ -20,10 +20,16 @@ def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
bytecode_path = save_mlir(
bytecode,
model_name="shark_eager_module",
frontend="torch",
mlir_dialect="tm_tensor",
)
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
mlir_module=bytecode_path,
device=device,
mlir_dialect="tm_tensor",
)

View File

@@ -3,8 +3,8 @@ import json
import numpy as np
import torch_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs
from iree.compiler import compile_file
from shark.shark_importer import import_with_fx, get_f16_inputs, save_mlir
class GenerateConfigFile:
@@ -54,9 +54,15 @@ class GenerateConfigFile:
verbose=False,
)
module = module.operation.get_asm(large_elements_limit=4)
module_file = save_mlir(
module,
model_name="module_pre_split",
frontend="torch",
mlir_dialect="linalg",
)
compiled_module_str = str(
compile_str(
str(module),
compile_file(
module_file,
target_backends=[backend],
extra_args=[
"--compile-to=flow",

View File

@@ -749,3 +749,25 @@ def import_with_fx(
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
return mlir_module, func_name
# Saves an MLIR module (Python bytes object) to directory 'dir' under a filename derived from 'model_name', and returns the path to the saved file.
def save_mlir(
mlir_module,
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if frontend == "torch":
with open(mlir_path, "wb") as mlir_file:
mlir_file.write(mlir_module)
return mlir_path
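
For reference, a hypothetical call to the helper defined above (the bytecode value is a stand-in):

bytecode = b"module {}"  # stand-in for torch-mlir bytecode
path = save_mlir(
    bytecode,
    model_name="demo_model",
    frontend="torch",
    mlir_dialect="tm_tensor",
    dir="",  # an empty dir falls back to tempfile.gettempdir()
)
# path -> <tmpdir>/demo_model_torch_tm_tensor.mlir, ready to pass to SharkInference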

View File

@@ -39,7 +39,7 @@ class SharkInference:
Attributes
----------
mlir_module : str
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
path to an MLIR module on disk, or the module itself as a string; modules from torch-mlir are serialized in bytecode format.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -65,7 +65,7 @@ class SharkInference:
def __init__(
self,
mlir_module: bytes,
mlir_module,
device: str = "none",
mlir_dialect: str = "linalg",
is_benchmark: bool = False,
@@ -75,6 +75,14 @@ class SharkInference:
mmap: bool = True,
):
self.mlir_module = mlir_module
if mlir_module is not None:
if mlir_module and not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkInference with an MLIR string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to an MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.is_benchmark = is_benchmark
@@ -203,6 +211,7 @@ class SharkInference:
module_name=module_name,
extra_args=extra_args,
debug=debug,
compile_str=self.compile_str,
)
# load and return the module.
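
The constructor change above amounts to a small detection step: anything that is not an existing file on disk is treated as an in-memory module and compiled via compile_str. A sketch of that check in isolation; the function name and warning wording are hypothetical:

import os

def _is_in_memory_module(mlir_module):
    if mlir_module is None:
        return False  # nothing provided yet; a path may be supplied later
    # A path on disk: compile later with ireec.compile_file.
    if isinstance(mlir_module, (str, os.PathLike)) and os.path.isfile(mlir_module):
        return False
    # A string/bytecode object: compile with ireec.compile_str, at the cost of
    # duplicating the model in RAM at compile time.
    print(
        "Warning: compiling from an in-memory MLIR string/bytecode object "
        "duplicates the model in RAM; prefer passing a path to a .mlir file."
    )
    return True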

View File

@@ -45,7 +45,7 @@ class SharkRunner:
Attributes
----------
mlir_module : str
mlir_module represented in string.
path to the mlir_module on disk, or the module as a string or bytecode.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -74,6 +74,14 @@ class SharkRunner:
device_idx: int = None,
):
self.mlir_module = mlir_module
if self.mlir_module is not None:
if not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkRunner with an MLIR string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to an MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
@@ -91,6 +99,7 @@ class SharkRunner:
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
compile_str=self.compile_str,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]

View File

@@ -15,7 +15,7 @@
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from shark.backward_makefx import MakeFxModule
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
import numpy as np
from tqdm import tqdm
import sys
@@ -84,6 +84,12 @@ class SharkTrainer:
training=True,
mlir_type=mlir_type,
)
mlir_module = save_mlir(
mlir_module,
model_name="shark_model",
frontend="torch",
mlir_dialect=mlir_type,
)
self.shark_runner = SharkRunner(
mlir_module,
self.device,

View File

@@ -36,9 +36,7 @@ def create_module(model_name, tokenizer, device):
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
if os.path.isfile(mlir_path):
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
print(f"Found .mlir at {mlir_path}")
else:
(model_mlir, func_name) = import_with_fx(
model=opt_model,
@@ -50,9 +48,10 @@ def create_module(model_name, tokenizer, device):
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
del model_mlir
shark_module = SharkInference(
model_mlir,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,

View File

@@ -6,7 +6,7 @@ import numpy as np
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "facebook/opt-1.3b"
@@ -57,9 +57,10 @@ class OPTModuleTester:
with open(mlir_path, "w") as f:
f.write(mlir_module)
print(f"Saved mlir at {mlir_path}")
del mlir_module
shark_module = SharkInference(
mlir_module,
mlir_path,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=self.benchmark,

View File

@@ -2,7 +2,7 @@ import os
import torch
from transformers import AutoTokenizer, OPTForCausalLM
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
from shark_opt_wrapper import OPTForCausalLMModel
model_name = "facebook/opt-1.3b"
@@ -25,11 +25,13 @@ inputs = (
model=model,
inputs=inputs,
is_f16=False,
debug=True,
model_name=model_name.split("/")[1],
save_dir=".",
)
mlir_module = save_mlir(
mlir_module,
model_name=model_name.split("/")[1],
frontend="torch",
mlir_dialect="linalg",
)
shark_module = SharkInference(
mlir_module,
device="cpu-sync",

View File

@@ -36,7 +36,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
get_hf_img_cls_model,
get_fp16_model,
)
from shark.shark_importer import import_with_fx
from shark.shark_importer import import_with_fx, save_mlir
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")