Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-01-10 06:17:55 -05:00
Pipe through a debug option to iree compile utils. (#1796)
* Update compile_utils.py
* Pipe through a flag to toggle debug options in compile utils.
* Update SharkLLMBase.py
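For orientation, a minimal sketch of how the new flag is intended to be used from calling code. The constructor arguments and the module name are illustrative assumptions; only compile(), save_module(), and load_module() appear in this change:

    from shark.shark_inference import SharkInference

    # mlir_module is assumed to be an already-imported MLIR module.
    shark_module = SharkInference(mlir_module, device="cuda", mlir_dialect="tm_tensor")
    shark_module.compile()

    # debug=True is forwarded through export_iree_module_to_vmfb and
    # compile_module_to_flatbuffer into get_iree_common_args, which keeps
    # assertions and enables verification instead of stripping them.
    path = shark_module.save_module(module_name="stablelm", debug=True)
    shark_module.load_module(path)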
@@ -46,6 +46,7 @@ def compile_stableLM(
     model_vmfb_name,
     device="cuda",
     precision="fp32",
+    debug=False,
 ):
     from shark.shark_inference import SharkInference

@@ -92,7 +93,7 @@ def compile_stableLM(
     shark_module.compile()

     path = shark_module.save_module(
-        vmfb_path.parent.absolute(), vmfb_path.stem
+        vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
     )
     print("Saved vmfb at ", str(path))

@@ -444,6 +444,7 @@ class ShardedVicuna(VicunaBase):
         weight_group_size=128,
         compressed=False,
         extra_args_cmd=[],
+        debug=False,
     ) -> None:
         super().__init__(
             model_name,
@@ -454,6 +455,7 @@ class ShardedVicuna(VicunaBase):
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
+        self.debug = debug
         self.tokenizer = self.get_tokenizer()
         self.config = config_json
         self.weight_group_size = weight_group_size
@@ -641,7 +643,7 @@ class ShardedVicuna(VicunaBase):
         return device_idx

     def compile_lmhead(
-        self, lmh, hidden_states, device="cpu", device_idx=None
+        self, lmh, hidden_states, device="cpu", device_idx=None,
     ):
         # compile the lm head of the vicuna model
         # This can be used for both first and second vicuna, so only needs to be run once
@@ -689,7 +691,7 @@ class ShardedVicuna(VicunaBase):
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
         else:
-            shark_module.save_module(module_name="lmhead")
+            shark_module.save_module(module_name="lmhead", debug=self.debug)
             shark_module.load_module(vmfb_path)
         compiled_module = LMHeadCompiled(shark_module)
         return compiled_module
@@ -735,7 +737,7 @@ class ShardedVicuna(VicunaBase):
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
         else:
-            shark_module.save_module(module_name="norm")
+            shark_module.save_module(module_name="norm", debug=self.debug)
             shark_module.load_module(vmfb_path)
         compiled_module = VicunaNormCompiled(shark_module)
         return compiled_module
@@ -786,14 +788,14 @@ class ShardedVicuna(VicunaBase):
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
         else:
-            shark_module.save_module(module_name="embedding")
+            shark_module.save_module(module_name="embedding", debug=self.debug)
             shark_module.load_module(vmfb_path)
         compiled_module = VicunaEmbeddingCompiled(shark_module)

         return compiled_module

     def compile_to_vmfb_one_model(
-        self, inputs0, layers0, inputs1, layers1, device="cpu"
+        self, inputs0, layers0, inputs1, layers1, device="cpu",
     ):
         mlirs, modules = [], []
         assert len(layers0) == len(layers1)
@@ -956,6 +958,7 @@ class ShardedVicuna(VicunaBase):
                         "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     ]
                     + self.extra_args,
+                    debug=self.debug,
                 )
                 module.load_module(vmfb_path)
             modules.append(module)
@@ -1023,6 +1026,7 @@ class ShardedVicuna(VicunaBase):
                         "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     ]
                     + self.extra_args,
+                    debug=self.debug,
                 )
                 module.load_module(vmfb_path)
             modules.append(module)
@@ -1659,6 +1663,7 @@ class UnshardedVicuna(VicunaBase):
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
             ]
             + self.extra_args,
+            debug=self.debug,
         )
         print("Saved vic vmfb at ", str(path))
         shark_module.load_module(path)
@@ -3,7 +3,10 @@ from abc import ABC, abstractmethod

 class SharkLLMBase(ABC):
     def __init__(
-        self, model_name, hf_model_path=None, max_num_tokens=512
+        self,
+        model_name,
+        hf_model_path=None,
+        max_num_tokens=512,
     ) -> None:
         self.model_name = model_name
         self.hf_model_path = hf_model_path
@@ -71,6 +71,7 @@ class Falcon(SharkLLMBase):
         precision="fp32",
         falcon_mlir_path=None,
         falcon_vmfb_path=None,
+        debug=False,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_padding_length = 100
@@ -78,6 +79,7 @@ class Falcon(SharkLLMBase):
         self.precision = precision
         self.falcon_vmfb_path = falcon_vmfb_path
         self.falcon_mlir_path = falcon_mlir_path
+        self.debug = debug
         self.tokenizer = self.get_tokenizer()
         self.shark_model = self.compile()
         self.src_model = self.get_src_model()
@@ -208,6 +210,7 @@ class Falcon(SharkLLMBase):
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-spirv-index-bits=64",
             ],
+            debug=self.debug,
         )
         print("Saved falcon vmfb at ", str(path))
         shark_module.load_module(path)
@@ -178,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):


 def compile_module(
-    shark_module, extended_model_name, generate_vmfb, extra_args=[]
+    shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
 ):
     if generate_vmfb:
         vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
@@ -190,7 +190,7 @@ def compile_module(
                 "No vmfb found. Compiling and saving to {}".format(vmfb_path)
             )
             path = shark_module.save_module(
-                os.getcwd(), extended_model_name, extra_args
+                os.getcwd(), extended_model_name, extra_args, debug=debug
             )
             shark_module.load_module(path, extra_args=extra_args)
         else:
@@ -199,7 +199,7 @@ def compile_module(


 def compile_int_precision(
-    model, inputs, precision, device, generate_vmfb, extended_model_name
+    model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
 ):
     torchscript_module = import_with_fx(
         model,
@@ -251,6 +251,7 @@ def compile_int_precision(
             extended_model_name=extended_model_name,
             generate_vmfb=generate_vmfb,
             extra_args=extra_args,
+            debug=debug,
         ),
         bytecode,
     )
@@ -294,6 +295,7 @@ def shark_compile_through_fx_int(
         device,
         generate_or_load_vmfb,
         extended_model_name,
+        debug,
     )
     extra_args = [
         "--iree-hal-dump-executable-sources-to=ies",
@@ -32,11 +32,13 @@ class SharkStableLM(SharkLLMBase):
         max_num_tokens=512,
         device="cuda",
         precision="fp32",
+        debug="False",
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_len = 256
         self.device = device
         self.precision = precision
+        self.debug = debug
         self.tokenizer = self.get_tokenizer()
         self.shark_model = self.compile()

@@ -111,7 +113,7 @@ class SharkStableLM(SharkLLMBase):
         shark_module.compile()

         path = shark_module.save_module(
-            vmfb_path.parent.absolute(), vmfb_path.stem
+            vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
         )
         print("Saved vmfb at ", str(path))

@@ -570,6 +570,14 @@ p.add_argument(
     "in shark importer. Does nothing if import_mlir is false (the default).",
 )

+p.add_argument(
+    "--compile_debug",
+    default=False,
+    action=argparse.BooleanOptionalAction,
+    help="Flag to toggle debug assert/verify flags for imported IR in the"
+    "iree-compiler. Default to false.",
+)
+
 p.add_argument(
     "--iree_constant_folding",
     default=True,
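Since the new option uses argparse.BooleanOptionalAction, both a positive and a negated form of the flag are generated. The entry-point name below is a placeholder, not a script from this repository:

    # python some_shark_app.py --compile_debug      -> args.compile_debug == True
    # python some_shark_app.py --no-compile_debug   -> args.compile_debug == False
    # python some_shark_app.py                      -> args.compile_debug == False (default)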
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
                 )
             )
             path = shark_module.save_module(
-                os.getcwd(), model_name, extra_args
+                os.getcwd(), model_name, extra_args, debug=args.compile_debug
             )
             shark_module.load_module(path, extra_args=extra_args)
         else:
@@ -92,13 +92,27 @@ def get_iree_frontend_args(frontend):


 # Common args to be used given any frontend or device.
-def get_iree_common_args():
-    return [
+def get_iree_common_args(debug=False):
+    common_args = [
         "--iree-stream-resource-max-allocation-size=4294967295",
         "--iree-vm-bytecode-module-strip-source-map=true",
         "--iree-util-zero-fill-elided-attrs",
-        "--iree-opt-strip-assertions=true",
     ]
+    if debug == True:
+        common_args.extend(
+            [
+                "--iree-opt-strip-assertions=false",
+                "--verify=true",
+            ]
+        )
+    else:
+        common_args.extend(
+            [
+                "--iree-opt-strip-assertions=true",
+                "--verify=false",
+            ]
+        )
+    return common_args


 # Args that are suitable only for certain models or groups of models.
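Informally, the effect of the toggle on the common flag list (the allocation, strip-source-map, and zero-fill flags are unchanged in both cases):

    get_iree_common_args(debug=True)
    # [..., "--iree-opt-strip-assertions=false", "--verify=true"]
    get_iree_common_args(debug=False)  # the default
    # [..., "--iree-opt-strip-assertions=true", "--verify=false"]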
@@ -277,12 +291,13 @@ def compile_module_to_flatbuffer(
     model_config_path,
     extra_args,
     model_name="None",
+    debug=False,
 ):
     # Setup Compile arguments wrt to frontends.
     input_type = ""
     args = get_iree_frontend_args(frontend)
     args += get_iree_device_args(device, extra_args)
-    args += get_iree_common_args()
+    args += get_iree_common_args(debug=debug)
     args += get_model_specific_args()
     args += extra_args

@@ -409,10 +424,11 @@ def get_iree_compiled_module(
     extra_args: list = [],
     device_idx: int = None,
     mmap: bool = False,
+    debug: bool = False,
 ):
     """Given a module returns the compiled .vmfb and configs"""
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, frontend, model_config_path, extra_args
+        module, device, frontend, model_config_path, extra_args, debug
     )
     temp_file_to_unlink = None
     # TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -468,10 +484,11 @@ def export_iree_module_to_vmfb(
     model_config_path: str = None,
     module_name: str = None,
     extra_args: list = [],
+    debug: bool = False,
 ):
     # Compiles the module given specs and saves it as .vmfb file.
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, mlir_dialect, model_config_path, extra_args
+        module, device, mlir_dialect, model_config_path, extra_args, debug
     )
     if module_name is None:
         device_name = (
@@ -192,7 +192,9 @@ class SharkInference:

     # TODO: Instead of passing directory and having names decided by the module
     # , user may want to save the module with manual names.
-    def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
+    def save_module(
+        self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
+    ):
         return export_iree_module_to_vmfb(
             self.mlir_module,
             self.device,
@@ -200,6 +202,7 @@ class SharkInference:
             self.mlir_dialect,
             module_name=module_name,
             extra_args=extra_args,
+            debug=debug,
         )

     # load and return the module.
@@ -59,7 +59,7 @@ def create_module(model_name, tokenizer, device):
     )

     vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
-    shark_module.save_module(module_name=vmfb_name)
+    shark_module.save_module(module_name=vmfb_name, debug=False)
     vmfb_path = vmfb_name + ".vmfb"
     return vmfb_path
