Compare commits


14 Commits

Author SHA1 Message Date
Ean Garvey a638d1d5d4 fix quant imports 2023-08-24 23:28:09 -05:00
Ean Garvey 8298865bda Merge branch 'main' into msvc-rocm 2023-08-24 23:25:34 -05:00
Ean Garvey b086bf7d4f Merge branch 'main' into msvc-rocm 2023-08-24 01:04:39 -05:00
Ean Garvey e644fdf38a Fix formatting and regex. 2023-08-16 14:43:08 -05:00
Ean Garvey ac01cfa5cc Update stable_args.py 2023-08-16 13:28:39 -05:00
Ean Garvey c22416cbb5 Guard quantization imports 2023-08-16 13:26:51 -05:00
Ean Garvey 7d77d6cfb2 Update rocm arg handling in SD utils 2023-08-16 13:23:37 -05:00
Ean Garvey c9cdc8f3c7 Update stable_args.py 2023-08-16 13:14:25 -05:00
Ean Garvey 3f33ea0f46 Make get_iree_rocm_args platform-agnostic. 2023-08-16 13:13:37 -05:00
Ean Garvey 5916e1c89e Update benchmark_utils.py 2023-08-16 12:57:33 -05:00
Ean Garvey 5954a0563c Update _common.py 2023-08-16 12:54:21 -05:00
Ean Garvey c73b805719 Fix brevitas imports 2023-08-16 12:50:50 -05:00
Ean Garvey 0d787c7c80 Delete opt_model.py 2023-08-16 12:48:40 -05:00
Ean Garvey 6f05a8b934 WIP: MSVC ROCM support for SHARK Studio 2023-08-16 00:13:19 -05:00
19 changed files with 307 additions and 988 deletions

View File

@@ -237,7 +237,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
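This pipeline-string swap (torch-unpack-quant-tensor plus torch-convert-custom-quant-op collapsing into torch-unpack-torch-tensor) recurs at every quantized lowering site in this PR. A condensed sketch of the int4/int8 branch those sites share, built only from names visible in the hunks below and with the inputs argument abbreviated, so treat it as illustrative rather than exact:

module = torch_mlir.compile(
    ts_graph,
    inputs,  # placeholder for the site-specific compile inputs
    output_type=torch_mlir.OutputType.TORCH,
    backend_legal_ops=["quant.matmul_rhs_group_quant"],
    extra_library=brevitas_matmul_rhs_group_quant_library,
    use_tracing=False,
    verbose=False,
)
run_pipeline_with_repro_report(
    module,
    "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
    description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)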

View File

@@ -46,7 +46,6 @@ def compile_stableLM(
model_vmfb_name,
device="cuda",
precision="fp32",
-debug=False,
):
from shark.shark_inference import SharkInference
@@ -93,7 +92,7 @@ def compile_stableLM(
shark_module.compile()
path = shark_module.save_module(
-vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
+vmfb_path.parent.absolute(), vmfb_path.stem
)
print("Saved vmfb at ", str(path))

View File

@@ -39,7 +39,6 @@ from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
SecondVicuna13B,
-SecondVicuna70B,
)
from apps.language_models.utils import (
get_vmfb_from_path,
@@ -192,6 +191,7 @@ class VicunaBase(SharkLLMBase):
first_vicuna_mlir,
second_vicuna_mlir,
output_name,
+model_name=None,
):
print(f"[DEBUG] combining first and second mlir")
print(f"[DEBUG] output_name = {output_name}")
@@ -355,7 +355,8 @@ class VicunaBase(SharkLLMBase):
f_.writelines(line + "\n" for line in global_vars)
f_.writelines(line + "\n" for line in f1)
f_.writelines(line + "\n" for line in f2)
-f_.writelines(line + "\n" for line in [module_end])
+if not (model_name and "llama2_13b" in model_name):
+f_.writelines(line + "\n" for line in [module_end])
del maps1
del maps2
@@ -441,7 +442,6 @@ class ShardedVicuna(VicunaBase):
weight_group_size=128,
compressed=False,
extra_args_cmd=[],
-debug=False,
) -> None:
super().__init__(
model_name,
@@ -452,7 +452,6 @@ class ShardedVicuna(VicunaBase):
self.max_sequence_length = 256
self.device = device
self.precision = precision
-self.debug = debug
self.tokenizer = self.get_tokenizer()
self.config = config_json
self.weight_group_size = weight_group_size
@@ -640,7 +639,7 @@ class ShardedVicuna(VicunaBase):
return device_idx
def compile_lmhead(
-self, lmh, hidden_states, device="cpu", device_idx=None,
+self, lmh, hidden_states, device="cpu", device_idx=None
):
# compile the lm head of the vicuna model
# This can be used for both first and second vicuna, so only needs to be run once
@@ -688,7 +687,7 @@ class ShardedVicuna(VicunaBase):
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="lmhead", debug=self.debug)
shark_module.save_module(module_name="lmhead")
shark_module.load_module(vmfb_path)
compiled_module = LMHeadCompiled(shark_module)
return compiled_module
@@ -734,7 +733,7 @@ class ShardedVicuna(VicunaBase):
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="norm", debug=self.debug)
shark_module.save_module(module_name="norm")
shark_module.load_module(vmfb_path)
compiled_module = VicunaNormCompiled(shark_module)
return compiled_module
@@ -785,14 +784,14 @@ class ShardedVicuna(VicunaBase):
if vmfb_path.exists():
shark_module.load_module(vmfb_path)
else:
shark_module.save_module(module_name="embedding", debug=self.debug)
shark_module.save_module(module_name="embedding")
shark_module.load_module(vmfb_path)
compiled_module = VicunaEmbeddingCompiled(shark_module)
return compiled_module
def compile_to_vmfb_one_model(
-self, inputs0, layers0, inputs1, layers1, device="cpu",
+self, inputs0, layers0, inputs1, layers1, device="cpu"
):
mlirs, modules = [], []
assert len(layers0) == len(layers1)
@@ -802,6 +801,7 @@ class ShardedVicuna(VicunaBase):
# if vmfb_path.exists():
# continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
@@ -855,7 +855,7 @@ class ShardedVicuna(VicunaBase):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module0,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -899,7 +899,7 @@ class ShardedVicuna(VicunaBase):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module1,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -924,6 +924,7 @@ class ShardedVicuna(VicunaBase):
mlirs.append(module_combined)
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
@@ -955,7 +956,6 @@ class ShardedVicuna(VicunaBase):
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ self.extra_args,
-debug=self.debug,
)
module.load_module(vmfb_path)
modules.append(module)
@@ -972,6 +972,7 @@ class ShardedVicuna(VicunaBase):
# if vmfb_path.exists():
# continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
@@ -990,6 +991,7 @@ class ShardedVicuna(VicunaBase):
mlirs.append(bytecode)
if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
@@ -1021,7 +1023,6 @@ class ShardedVicuna(VicunaBase):
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ self.extra_args,
-debug=self.debug,
)
module.load_module(vmfb_path)
modules.append(module)
@@ -1230,13 +1231,12 @@ class UnshardedVicuna(VicunaBase):
precision="int8",
vicuna_mlir_path=None,
vicuna_vmfb_path=None,
-load_mlir_from_shark_tank=False,
+load_mlir_from_shark_tank=True,
low_device_memory=False,
weight_group_size=128,
download_vmfb=False,
cache_vicunas=False,
extra_args_cmd=[],
-debug=False,
) -> None:
super().__init__(
model_name,
@@ -1265,7 +1265,6 @@ class UnshardedVicuna(VicunaBase):
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
self.low_device_memory = low_device_memory
self.weight_group_size = weight_group_size
-self.debug = debug
if self.vicuna_mlir_path == None:
self.vicuna_mlir_path = self.get_model_path()
if self.vicuna_vmfb_path == None:
@@ -1276,7 +1275,7 @@ class UnshardedVicuna(VicunaBase):
def get_model_path(self, suffix="mlir"):
safe_device = self.device.split("-")[0]
-if suffix in ["mlirbc", "mlir"]:
+if suffix == "mlir":
return Path(f"{self.model_name}_{self.precision}.{suffix}")
return Path(
f"{self.model_name}_{self.precision}_{safe_device}.{suffix}"
@@ -1336,7 +1335,7 @@ class UnshardedVicuna(VicunaBase):
new_lines.append(line)
return "\n".join(new_lines)
-def write_in_dynamic_inputs1(self, module):
+def write_in_dynamic_inputs1(self, module, model_name):
print("[DEBUG] writing dynamic inputs to second vicuna")
def remove_constant_dim(line):
@@ -1364,12 +1363,9 @@ class UnshardedVicuna(VicunaBase):
module = module.splitlines()
new_lines = []
# Using a while loop and the pop method to avoid creating a copy of module
if "llama2_13b" in self.model_name:
if "llama2_13b" in model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
elif "llama2_70b" in self.model_name:
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
if self.precision in ["fp16", "int4", "int8"]:
@@ -1404,13 +1400,13 @@ class UnshardedVicuna(VicunaBase):
return "\n".join(new_lines)
-def compile(self):
+def compile(self, download_vmfb=False):
# Testing : DO NOT Download Vmfbs if not found. Modify later
# download vmfbs for A100
-if not self.vicuna_vmfb_path.exists() and self.download_vmfb:
-print(
-f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
-)
+print(
+f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
+)
+if not self.vicuna_vmfb_path.exists() and download_vmfb:
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
self.vicuna_vmfb_path.absolute(),
@@ -1423,237 +1419,233 @@ class UnshardedVicuna(VicunaBase):
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
return
print(f"[DEBUG] vmfb not found")
mlir_generated = False
for suffix in ["mlirbc", "mlir"]:
self.vicuna_mlir_path = self.get_model_path(suffix)
if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
print(
f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
)
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
self.vicuna_mlir_path.absolute(),
single_file=True,
)
if self.vicuna_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
mlir_generated = True
break
if not mlir_generated:
print(f"[DEBUG] mlir not found")
print("[DEBUG] generating mlir on device")
# Select a compilation prompt such that the resulting input_ids
# from the model's tokenizer has shape [1, 19]
if self.model_name == "codegen":
compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
else:
compilation_prompt = "".join(["0" for _ in range(17)])
first_model_path = f"first_{self.model_name}_{self.precision}.mlir"
if Path(first_model_path).exists():
print(f"loading {first_model_path}")
with open(Path(first_model_path), "r") as f:
first_module = f.read()
else:
# generate first vicuna
compilation_input_ids = self.tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=is_f16,
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[
0
] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
first_module = None
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
first_module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
first_module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
first_module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
del firstVicunaCompileInput
gc.collect()
print(
"[DEBUG] successfully generated first vicuna linalg mlir"
)
first_module = self.write_in_dynamic_inputs0(
str(first_module), dynamic_input_size=19
)
if self.cache_vicunas:
with open(first_model_path, "w+") as f:
f.write(first_module)
print("Finished writing IR after dynamic")
print(f"[DEBUG] Starting generation of second llama")
second_model_path = f"second_{self.model_name}_{self.precision}.mlir"
if Path(second_model_path).exists():
print(f"loading {second_model_path}")
with open(Path(second_model_path), "r") as f:
second_module = f.read()
else:
# generate second vicuna
compilation_input_ids = torch.zeros(
[1, 1], dtype=torch.int64
)
if self.model_name == "llama2_13b":
dim1 = 40
total_tuple = 80
elif self.model_name == "llama2_70b":
dim1 = 8
total_tuple = 160
else:
dim1 = 32
total_tuple = 64
pkv = tuple(
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
for _ in range(total_tuple)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
if self.model_name == "llama2_13b":
model = SecondVicuna13B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
elif self.model_name == "llama2_70b":
model = SecondVicuna70B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
else:
model = SecondVicuna7B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=is_f16,
precision=self.precision,
f16_input_mask=[False] + [True] * total_tuple,
mlir_type="torchscript",
)
del model
if self.precision in ["fp16", "int4"]:
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * total_tuple,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
second_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
second_module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
second_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
del secondVicunaCompileInput
gc.collect()
print(
"[DEBUG] successfully generated second vicuna linalg mlir"
)
second_module = self.write_in_dynamic_inputs1(
str(second_module)
)
if self.cache_vicunas:
with open(second_model_path, "w+") as f:
f.write(second_module)
print("Finished writing IR after dynamic")
combined_module = self.combine_mlir_scripts(
first_module,
second_module,
self.vicuna_mlir_path,
print(f"[DEBUG] vmfb not found at {self.vicuna_vmfb_path.absolute()}")
if self.vicuna_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
else:
print(
f"[DEBUG] mlir not found at {self.vicuna_mlir_path.absolute()}"
)
del first_module, second_module
mlir_generated = False
if self.load_mlir_from_shark_tank:
# download MLIR from shark tank
for suffix in ["mlir", "mlirbc"]:
self.vicuna_mlir_path = self.get_model_path(suffix)
download_public_file(
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
self.vicuna_mlir_path.absolute(),
single_file=True,
)
if self.vicuna_mlir_path.exists():
with open(self.vicuna_mlir_path, "rb") as f:
combined_module = f.read()
mlir_generated = True
break
self.vicuna_mlir_path = self.get_model_path("mlir")
if not mlir_generated:
print(
f"[DEBUG] failed to download {self.vicuna_mlir_path.name} from shark tank"
)
if not mlir_generated:
print("[DEBUG] generating mlir on device")
# Select a compilation prompt such that the resulting input_ids
# from the model's tokenizer has shape [1, 19]
if self.model_name == "codegen":
compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
else:
compilation_prompt = "".join(["0" for _ in range(17)])
if Path(f"first_{self.precision}.mlir").exists():
print(f"loading first_{self.precision}.mlir")
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
first_module = f.read()
else:
# generate first vicuna
compilation_input_ids = self.tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=is_f16,
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[
0
] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
first_module = None
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
first_module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
first_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
first_module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
del firstVicunaCompileInput
gc.collect()
print(
"[DEBUG] successfully generated first vicuna linalg mlir"
)
first_module = self.write_in_dynamic_inputs0(
str(first_module), dynamic_input_size=19
)
if self.cache_vicunas:
with open(f"first_{self.precision}.mlir", "w+") as f:
f.write(first_module)
print("Finished writing IR after dynamic")
if Path(f"second_{self.precision}.mlir").exists():
print(f"loading second_{self.precision}.mlir")
with open(Path(f"second_{self.precision}.mlir"), "r") as f:
second_module = f.read()
else:
# generate second vicuna
compilation_input_ids = torch.zeros(
[1, 1], dtype=torch.int64
)
if self.model_name == "llama2_13b":
dim1 = 40
total_tuple = 80
else:
dim1 = 32
total_tuple = 64
pkv = tuple(
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
for _ in range(total_tuple)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
if self.model_name == "llama2_13b":
model = SecondVicuna13B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
else:
model = SecondVicuna7B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=is_f16,
precision=self.precision,
f16_input_mask=[False] + [True] * total_tuple,
mlir_type="torchscript",
)
del model
if self.precision in ["fp16", "int4"]:
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * total_tuple,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
second_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
run_pipeline_with_repro_report(
second_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
second_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
del secondVicunaCompileInput
gc.collect()
print(
"[DEBUG] successfully generated second vicuna linalg mlir"
)
second_module = self.write_in_dynamic_inputs1(
str(second_module)
)
if self.cache_vicunas:
with open(f"second_{self.precision}.mlir", "w") as f:
f.write(second_module)
print("Finished writing IR after dynamic")
combined_module = self.combine_mlir_scripts(
first_module,
second_module,
self.vicuna_mlir_path,
self.model_name,
)
del first_module, second_module
print(self.device)
if "rocm" in self.device:
@@ -1672,7 +1664,6 @@ class UnshardedVicuna(VicunaBase):
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ self.extra_args,
-debug=self.debug,
)
print("Saved vic vmfb at ", str(path))
shark_module.load_module(path)
@@ -1740,6 +1731,7 @@ class UnshardedVicuna(VicunaBase):
yield detok, ""
res_str = self.decode_tokens(res_tokens)
# print(f"[DEBUG] final output : \n{res_str}")
yield res_str, "formatted"
def autocomplete(self, prompt):
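One behavioral detail from the hunks above: write_in_dynamic_inputs1 now derives the past-key-value tensor shape from its new model_name parameter instead of self.model_name, and the llama2_70b case (tensor<1x8x?x128x) is gone along with SecondVicuna70B. A hypothetical helper capturing the remaining shape choice, for illustration only:

def pkv_tensor_shape_prefix(model_name: str) -> str:
    # llama2_13b checkpoints carry 40 KV heads per layer in this codebase;
    # the other supported models use 32. Head dimension stays 128.
    if "llama2_13b" in model_name:
        return "tensor<1x40x?x128x"
    return "tensor<1x32x?x128x"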

View File

@@ -1,94 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('huggingface-hub')
datas += copy_metadata('sentencepiece')
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('py-cpuinfo')
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/vicuna.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_llama_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -18,7 +18,6 @@ class FirstVicuna(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
@@ -67,7 +66,6 @@ class SecondVicuna7B(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
@@ -157,6 +155,8 @@ class SecondVicuna7B(torch.nn.Module):
i63,
i64,
):
+# input_ids = input_tuple[0]
+# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
@@ -301,7 +301,7 @@ class SecondVicuna13B(torch.nn.Module):
def __init__(
self,
model_path,
precision="int8",
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -418,6 +418,8 @@ class SecondVicuna13B(torch.nn.Module):
i79,
i80,
):
+# input_ids = input_tuple[0]
+# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
@@ -590,540 +592,6 @@ class SecondVicuna13B(torch.nn.Module):
return tuple(return_vals)
class SecondVicuna70B(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float16,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
i65,
i66,
i67,
i68,
i69,
i70,
i71,
i72,
i73,
i74,
i75,
i76,
i77,
i78,
i79,
i80,
i81,
i82,
i83,
i84,
i85,
i86,
i87,
i88,
i89,
i90,
i91,
i92,
i93,
i94,
i95,
i96,
i97,
i98,
i99,
i100,
i101,
i102,
i103,
i104,
i105,
i106,
i107,
i108,
i109,
i110,
i111,
i112,
i113,
i114,
i115,
i116,
i117,
i118,
i119,
i120,
i121,
i122,
i123,
i124,
i125,
i126,
i127,
i128,
i129,
i130,
i131,
i132,
i133,
i134,
i135,
i136,
i137,
i138,
i139,
i140,
i141,
i142,
i143,
i144,
i145,
i146,
i147,
i148,
i149,
i150,
i151,
i152,
i153,
i154,
i155,
i156,
i157,
i158,
i159,
i160,
):
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
(
i65,
i66,
),
(
i67,
i68,
),
(
i69,
i70,
),
(
i71,
i72,
),
(
i73,
i74,
),
(
i75,
i76,
),
(
i77,
i78,
),
(
i79,
i80,
),
(
i81,
i82,
),
(
i83,
i84,
),
(
i85,
i86,
),
(
i87,
i88,
),
(
i89,
i90,
),
(
i91,
i92,
),
(
i93,
i94,
),
(
i95,
i96,
),
(
i97,
i98,
),
(
i99,
i100,
),
(
i101,
i102,
),
(
i103,
i104,
),
(
i105,
i106,
),
(
i107,
i108,
),
(
i109,
i110,
),
(
i111,
i112,
),
(
i113,
i114,
),
(
i115,
i116,
),
(
i117,
i118,
),
(
i119,
i120,
),
(
i121,
i122,
),
(
i123,
i124,
),
(
i125,
i126,
),
(
i127,
i128,
),
(
i129,
i130,
),
(
i131,
i132,
),
(
i133,
i134,
),
(
i135,
i136,
),
(
i137,
i138,
),
(
i139,
i140,
),
(
i141,
i142,
),
(
i143,
i144,
),
(
i145,
i146,
),
(
i147,
i148,
),
(
i149,
i150,
),
(
i151,
i152,
),
(
i153,
i154,
),
(
i155,
i156,
),
(
i157,
i158,
),
(
i159,
i160,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0])
return_vals.append(item[1])
return tuple(return_vals)
class CombinedModel(torch.nn.Module):
def __init__(
self,

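All of the SecondVicuna* wrappers, including the deleted 70B one above, follow the same convention: the flattened forward arguments i1..iN are regrouped into per-layer (key, value) pairs before being handed to the underlying model. A generic sketch of that packing (hypothetical helper, not part of the diff):

def pack_past_key_values(flat):
    # Rebuild ((k0, v0), (k1, v1), ...) from [k0, v0, k1, v1, ...];
    # assumes an even number of tensors.
    return tuple((flat[i], flat[i + 1]) for i in range(0, len(flat), 2))

The counts line up with the model configs used elsewhere in the PR: 64 tensors for the 7B's 32 layers, 80 for the 13B's 40 layers, and 160 for the removed 70B wrapper's 80 layers.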
View File

@@ -3,10 +3,7 @@ from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
-self,
-model_name,
-hf_model_path=None,
-max_num_tokens=512,
+self, model_name, hf_model_path=None, max_num_tokens=512
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path

View File

@@ -71,7 +71,6 @@ class Falcon(SharkLLMBase):
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
-debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_padding_length = 100
@@ -79,7 +78,6 @@ class Falcon(SharkLLMBase):
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
-self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
@@ -210,7 +208,6 @@ class Falcon(SharkLLMBase):
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
"--iree-spirv-index-bits=64",
],
-debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)

View File

@@ -178,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
def compile_module(
-shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
+shark_module, extended_model_name, generate_vmfb, extra_args=[]
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
@@ -190,7 +190,7 @@ def compile_module(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
-os.getcwd(), extended_model_name, extra_args, debug=debug
+os.getcwd(), extended_model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -199,7 +199,7 @@ def compile_module(
def compile_int_precision(
-model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
+model, inputs, precision, device, generate_vmfb, extended_model_name
):
torchscript_module = import_with_fx(
model,
@@ -219,7 +219,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout
@@ -251,7 +251,6 @@ def compile_int_precision(
extended_model_name=extended_model_name,
generate_vmfb=generate_vmfb,
extra_args=extra_args,
-debug=debug,
),
bytecode,
)
@@ -295,7 +294,6 @@ def shark_compile_through_fx_int(
device,
generate_or_load_vmfb,
extended_model_name,
-debug,
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",

View File

@@ -32,13 +32,11 @@ class SharkStableLM(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
debug="False",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
-self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
@@ -113,7 +111,7 @@ class SharkStableLM(SharkLLMBase):
shark_module.compile()
path = shark_module.save_module(
-vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
+vmfb_path.parent.absolute(), vmfb_path.stem
)
print("Saved vmfb at ", str(path))

View File

@@ -74,11 +74,8 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
-blacklist = ["tests", "convert"]
hiddenimports += [
-x
-for x in collect_submodules("transformers")
-if not any(kw in x for kw in blacklist)
+x for x in collect_submodules("transformers") if "tests" not in x
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree._runtime_libs"]

View File

@@ -570,14 +570,6 @@ p.add_argument(
"in shark importer. Does nothing if import_mlir is false (the default).",
)
-p.add_argument(
-"--compile_debug",
-default=False,
-action=argparse.BooleanOptionalAction,
-help="Flag to toggle debug assert/verify flags for imported IR in the"
-"iree-compiler. Default to false.",
-)
p.add_argument(
"--iree_constant_folding",
default=True,

View File

@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
)
)
path = shark_module.save_module(
-os.getcwd(), model_name, extra_args, debug=args.compile_debug
+os.getcwd(), model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:

View File

@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
gr.Markdown(description)
with gr.Row():
-with gr.Column():
+with gr.Column(scale=0.5):
image = gr.Image(type="pil")
upload_button = gr.Button(
value="Upload & Start Chat",

View File

@@ -160,15 +160,14 @@ def chat(
model,
device,
precision,
-download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global past_key_values
global model_vmfb_key
global vicuna_model
global vicuna_model
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
@@ -178,8 +177,6 @@ def chat(
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
@@ -197,6 +194,20 @@ def chat(
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
if new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
@@ -228,8 +239,6 @@ def chat(
device=device,
precision=precision,
max_num_tokens=max_toks,
-download_vmfb=download_vmfb,
-load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
)
# else:
@@ -353,8 +362,6 @@ def llm_chat_api(InputData: dict):
device=device,
precision=precision,
max_num_tokens=max_toks,
-download_vmfb=True,
-load_mlir_from_shark_tank=True,
)
# TODO: add role dict for different models
@@ -425,14 +432,15 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
-device = gr.Dropdown(
+# print(supported_devices)
+devices = gr.Dropdown(
label="Device",
value=supported_devices[0]
if enabled
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
# multiselect=True,
# multiselect=True,
)
precision = gr.Radio(
label="Precision",
@@ -444,13 +452,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
],
visible=True,
)
-with gr.Column():
-download_vmfb = gr.Checkbox(
-label="Download vmfb from Shark tank if available",
-value=True,
-interactive=True,
-)
-tokens_time = gr.Textbox(label="Tokens generated per second")
+tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Row(visible=False):
with gr.Group():
@@ -485,15 +487,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
-inputs=[
-system_msg,
-chatbot,
-model,
-device,
-precision,
-download_vmfb,
-config_file,
-],
+inputs=[system_msg, chatbot, model, devices, precision, config_file],
outputs=[chatbot, tokens_time],
queue=True,
)
@@ -501,15 +495,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
-inputs=[
-system_msg,
-chatbot,
-model,
-device,
-precision,
-download_vmfb,
-config_file,
-],
+inputs=[system_msg, chatbot, model, devices, precision, config_file],
outputs=[chatbot, tokens_time],
queue=True,
)

View File

@@ -7,13 +7,22 @@ import fileinput
from pathlib import Path
# Temporary workaround for transformers/__init__.py.
+path_to_stdhooks = Path(
+get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks"
+)
path_to_transformers_hook = Path(
-get_python_lib()
-+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
+str(path_to_stdhooks) + "hook-transformers.py"
)
if path_to_transformers_hook.is_file():
pass
else:
+if not path_to_stdhooks.is_dir():
+import os
+print(
+f"Path to pyinstaller stdhooks not found. Please check your pyinstaller packages at {path_to_stdhooks}."
+)
+os.mkdir(path_to_stdhooks)
with open(path_to_transformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
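As rendered here, the new str(path_to_stdhooks) + "hook-transformers.py" concatenation has no path separator between the directory and the filename. A sketch of the same construction using pathlib's / operator, which supplies that separator; assuming get_python_lib comes from distutils.sysconfig, which is not shown in the hunk:

from pathlib import Path
from distutils.sysconfig import get_python_lib

path_to_stdhooks = Path(get_python_lib()) / "_pyinstaller_hooks_contrib" / "hooks" / "stdhooks"
# The / operator inserts the separator that plain string concatenation omits.
path_to_transformers_hook = path_to_stdhooks / "hook-transformers.py"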

View File

@@ -92,27 +92,13 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
-def get_iree_common_args(debug=False):
-common_args = [
+def get_iree_common_args():
+return [
"--iree-stream-resource-max-allocation-size=4294967295",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
"--iree-opt-strip-assertions=true",
]
-if debug == True:
-common_args.extend(
-[
-"--iree-opt-strip-assertions=false",
-"--verify=true",
-]
-)
-else:
-common_args.extend(
-[
-"--iree-opt-strip-assertions=true",
-"--verify=false",
-]
-)
-return common_args
# Args that are suitable only for certain models or groups of models.
@@ -291,13 +277,12 @@ def compile_module_to_flatbuffer(
model_config_path,
extra_args,
model_name="None",
-debug=False,
):
# Setup Compile arguments wrt to frontends.
input_type = ""
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device, extra_args)
-args += get_iree_common_args(debug=debug)
+args += get_iree_common_args()
args += get_model_specific_args()
args += extra_args
@@ -425,11 +410,10 @@ def get_iree_compiled_module(
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
-debug: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
-module, device, frontend, model_config_path, extra_args, debug
+module, device, frontend, model_config_path, extra_args
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -485,11 +469,10 @@ def export_iree_module_to_vmfb(
model_config_path: str = None,
module_name: str = None,
extra_args: list = [],
-debug: bool = False,
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
-module, device, mlir_dialect, model_config_path, extra_args, debug
+module, device, mlir_dialect, model_config_path, extra_args
)
if module_name is None:
device_name = (

View File

@@ -115,7 +115,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout

View File

@@ -192,9 +192,7 @@ class SharkInference:
# TODO: Instead of passing directory and having names decided by the module
# , user may want to save the module with manual names.
-def save_module(
-self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
-):
+def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
return export_iree_module_to_vmfb(
self.mlir_module,
self.device,
@@ -202,7 +200,6 @@ class SharkInference:
self.mlir_dialect,
module_name=module_name,
extra_args=extra_args,
-debug=debug,
)
# load and return the module.
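With the debug keyword gone from save_module, call sites pass only the output directory, module name, and extra compiler flags; assert stripping is now unconditional via get_iree_common_args(). A minimal usage sketch of the updated API; the constructor arguments are assumptions based on the call sites in this PR:

import os
from shark.shark_inference import SharkInference

shark_module = SharkInference(mlir_module, device="cpu-task", mlir_dialect="linalg")
shark_module.compile()
# debug= is no longer accepted here or in export_iree_module_to_vmfb.
path = shark_module.save_module(os.getcwd(), module_name="my_model", extra_args=[])
shark_module.load_module(path)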

View File

@@ -59,7 +59,7 @@ def create_module(model_name, tokenizer, device):
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
-shark_module.save_module(module_name=vmfb_name, debug=False)
+shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path