Pipe through a debug option to iree compile utils. (#1796)

* Update compile_utils.py

* Pipe through a flag to toggle debug options in compile utils.

* Update SharkLLMBase.py
Ean Garvey
2023-08-25 09:11:11 -05:00
committed by GitHub
parent 450c231171
commit 9697981004
11 changed files with 64 additions and 20 deletions
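Taken together, the hunks below thread a single debug switch from the CLI entry point down to the IREE compiler flags. A minimal sketch of the intended flow, using names that appear in this diff; the module object and the SharkInference constructor arguments are placeholders, not part of the commit:

    import os
    from shark.shark_inference import SharkInference

    # mlir_module and the device string are placeholders; the point is the new
    # debug= keyword, which save_module() forwards through
    # export_iree_module_to_vmfb() and compile_module_to_flatbuffer() into
    # get_iree_common_args(), where it toggles the assert/verify flags.
    shark_module = SharkInference(mlir_module, device="cuda")
    shark_module.compile()
    path = shark_module.save_module(os.getcwd(), "example_model", debug=True)
    shark_module.load_module(path)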

View File

@@ -46,6 +46,7 @@ def compile_stableLM(
model_vmfb_name,
device="cuda",
precision="fp32",
+ debug=False,
):
from shark.shark_inference import SharkInference
@@ -92,7 +93,7 @@ def compile_stableLM(
shark_module.compile()
path = shark_module.save_module(
- vmfb_path.parent.absolute(), vmfb_path.stem
+ vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
)
print("Saved vmfb at ", str(path))

View File

@@ -444,6 +444,7 @@ class ShardedVicuna(VicunaBase):
weight_group_size=128,
compressed=False,
extra_args_cmd=[],
+ debug=False,
) -> None:
super().__init__(
model_name,
@@ -454,6 +455,7 @@ class ShardedVicuna(VicunaBase):
self.max_sequence_length = 256
self.device = device
self.precision = precision
+ self.debug = debug
self.tokenizer = self.get_tokenizer()
self.config = config_json
self.weight_group_size = weight_group_size
@@ -641,7 +643,7 @@ class ShardedVicuna(VicunaBase):
return device_idx
def compile_lmhead(
- self, lmh, hidden_states, device="cpu", device_idx=None
+ self, lmh, hidden_states, device="cpu", device_idx=None,
):
# compile the lm head of the vicuna model
# This can be used for both first and second vicuna, so only needs to be run once
@@ -689,7 +691,7 @@ class ShardedVicuna(VicunaBase):
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
- shark_module.save_module(module_name="lmhead")
+ shark_module.save_module(module_name="lmhead", debug=self.debug)
shark_module.load_module(vmfb_path)
compiled_module = LMHeadCompiled(shark_module)
return compiled_module
@@ -735,7 +737,7 @@ class ShardedVicuna(VicunaBase):
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
- shark_module.save_module(module_name="norm")
+ shark_module.save_module(module_name="norm", debug=self.debug)
shark_module.load_module(vmfb_path)
compiled_module = VicunaNormCompiled(shark_module)
return compiled_module
@@ -786,14 +788,14 @@ class ShardedVicuna(VicunaBase):
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
- shark_module.save_module(module_name="embedding")
+ shark_module.save_module(module_name="embedding", debug=self.debug)
shark_module.load_module(vmfb_path)
compiled_module = VicunaEmbeddingCompiled(shark_module)
return compiled_module
def compile_to_vmfb_one_model(
- self, inputs0, layers0, inputs1, layers1, device="cpu"
+ self, inputs0, layers0, inputs1, layers1, device="cpu",
):
mlirs, modules = [], []
assert len(layers0) == len(layers1)
@@ -956,6 +958,7 @@ class ShardedVicuna(VicunaBase):
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ self.extra_args,
+ debug=self.debug,
)
module.load_module(vmfb_path)
modules.append(module)
@@ -1023,6 +1026,7 @@ class ShardedVicuna(VicunaBase):
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ self.extra_args,
+ debug=self.debug,
)
module.load_module(vmfb_path)
modules.append(module)
@@ -1659,6 +1663,7 @@ class UnshardedVicuna(VicunaBase):
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ self.extra_args,
+ debug=self.debug,
)
print("Saved vic vmfb at ", str(path))
shark_module.load_module(path)

View File

@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
- self, model_name, hf_model_path=None, max_num_tokens=512
+ self,
+ model_name,
+ hf_model_path=None,
+ max_num_tokens=512,
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path

View File

@@ -71,6 +71,7 @@ class Falcon(SharkLLMBase):
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
+ debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
@@ -78,6 +79,7 @@ class Falcon(SharkLLMBase):
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
+ self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
@@ -208,6 +210,7 @@ class Falcon(SharkLLMBase):
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
+ debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
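In the LLM wrappers the flag is stored on the instance and only consulted at save time. A hedged usage sketch for the Falcon class; the model identifiers are placeholders, and any constructor arguments not visible in the hunk above are assumptions:

    # Hypothetical: compile Falcon with debug flags enabled; self.debug is
    # forwarded to shark_module.save_module(..., debug=self.debug) in compile().
    falcon = Falcon(
        "falcon-7b-instruct",
        hf_model_path="tiiuae/falcon-7b-instruct",
        precision="fp32",
        debug=True,
    )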

View File

@@ -178,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
def compile_module(
- shark_module, extended_model_name, generate_vmfb, extra_args=[]
+ shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
@@ -190,7 +190,7 @@ def compile_module(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
- os.getcwd(), extended_model_name, extra_args
+ os.getcwd(), extended_model_name, extra_args, debug=debug
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -199,7 +199,7 @@ def compile_module(
def compile_int_precision(
- model, inputs, precision, device, generate_vmfb, extended_model_name
+ model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
):
torchscript_module = import_with_fx(
model,
@@ -251,6 +251,7 @@ def compile_int_precision(
extended_model_name=extended_model_name,
generate_vmfb=generate_vmfb,
extra_args=extra_args,
+ debug=debug,
),
bytecode,
)
@@ -294,6 +295,7 @@ def shark_compile_through_fx_int(
device,
generate_or_load_vmfb,
extended_model_name,
+ debug,
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",

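The shared compile_utils path takes the same keyword, so callers that manage their own SharkInference instance can opt in per call. A hedged sketch against the updated compile_module signature; the module instance and model name are placeholders:

    # Hypothetical: build (or reuse a cached) vmfb with debug options enabled.
    compile_module(
        shark_module,                           # a compiled SharkInference instance
        extended_model_name="my_model_fp32_cuda",
        generate_vmfb=True,
        extra_args=[],
        debug=True,
    )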
View File

@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
+ debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
+ self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
shark_module.compile()
path = shark_module.save_module(
- vmfb_path.parent.absolute(), vmfb_path.stem
+ vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
)
print("Saved vmfb at ", str(path))

View File

@@ -570,6 +570,14 @@ p.add_argument(
"in shark importer. Does nothing if import_mlir is false (the default).",
)
+ p.add_argument(
+ "--compile_debug",
+ default=False,
+ action=argparse.BooleanOptionalAction,
+ help="Flag to toggle debug assert/verify flags for imported IR in the "
+ "iree-compiler. Defaults to False.",
+ )
p.add_argument(
"--iree_constant_folding",
default=True,
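Since the new flag uses argparse.BooleanOptionalAction (Python 3.9+), the parser accepts both --compile_debug and --no-compile_debug. A small standalone check of that behavior:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument(
        "--compile_debug",
        default=False,
        action=argparse.BooleanOptionalAction,
    )
    print(p.parse_args(["--compile_debug"]).compile_debug)     # True
    print(p.parse_args(["--no-compile_debug"]).compile_debug)  # False
    print(p.parse_args([]).compile_debug)                      # False (default)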

View File

@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
)
)
path = shark_module.save_module(
- os.getcwd(), model_name, extra_args
+ os.getcwd(), model_name, extra_args, debug=args.compile_debug
)
shark_module.load_module(path, extra_args=extra_args)
else:

View File

@@ -92,13 +92,27 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
- def get_iree_common_args():
- return [
+ def get_iree_common_args(debug=False):
+ common_args = [
"--iree-stream-resource-max-allocation-size=4294967295",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
- "--iree-opt-strip-assertions=true",
]
+ if debug == True:
+ common_args.extend(
+ [
+ "--iree-opt-strip-assertions=false",
+ "--verify=true",
+ ]
+ )
+ else:
+ common_args.extend(
+ [
+ "--iree-opt-strip-assertions=true",
+ "--verify=false",
+ ]
+ )
+ return common_args
# Args that are suitable only for certain models or groups of models.
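With this refactor the common argument list differs only in the assert/verify pair. A quick check of the two outcomes, assuming the function as defined above:

    release_args = get_iree_common_args()          # strips assertions, --verify=false
    debug_args = get_iree_common_args(debug=True)  # keeps assertions, --verify=true
    assert "--iree-opt-strip-assertions=true" in release_args
    assert "--iree-opt-strip-assertions=false" in debug_args
    assert "--verify=true" in debug_args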
@@ -277,12 +291,13 @@ def compile_module_to_flatbuffer(
model_config_path,
extra_args,
model_name="None",
+ debug=False,
):
# Setup Compile arguments wrt to frontends.
input_type = ""
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device, extra_args)
- args += get_iree_common_args()
+ args += get_iree_common_args(debug=debug)
args += get_model_specific_args()
args += extra_args
@@ -409,10 +424,11 @@ def get_iree_compiled_module(
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
+ debug: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
- module, device, frontend, model_config_path, extra_args
+ module, device, frontend, model_config_path, extra_args, debug
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -468,10 +484,11 @@ def export_iree_module_to_vmfb(
model_config_path: str = None,
module_name: str = None,
extra_args: list = [],
+ debug: bool = False,
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
- module, device, mlir_dialect, model_config_path, extra_args
+ module, device, mlir_dialect, model_config_path, extra_args, debug
)
if module_name is None:
device_name = (

View File

@@ -192,7 +192,9 @@ class SharkInference:
# TODO: Instead of passing directory and having names decided by the module
# , user may want to save the module with manual names.
- def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
+ def save_module(
+ self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
+ ):
return export_iree_module_to_vmfb(
self.mlir_module,
self.device,
@@ -200,6 +202,7 @@ class SharkInference:
self.mlir_dialect,
module_name=module_name,
extra_args=extra_args,
+ debug=debug,
)
# load and return the module.
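At the API surface, enabling debug output is now a single keyword on save_module. A short sketch against the updated SharkInference method; the module name is chosen for illustration:

    import os

    # shark_module is a compiled SharkInference instance (construction elided).
    path = shark_module.save_module(
        os.getcwd(),
        module_name="stableLM_debug",   # hypothetical name
        extra_args=[],
        debug=True,                     # keep assertions and enable --verify=true
    )
    shark_module.load_module(path)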

View File

@@ -59,7 +59,7 @@ def create_module(model_name, tokenizer, device):
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
- shark_module.save_module(module_name=vmfb_name)
+ shark_module.save_module(module_name=vmfb_name, debug=False)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path