mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
14 Commits
20230901.9
...
20230824.9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a638d1d5d4 | ||
|
|
8298865bda | ||
|
|
b086bf7d4f | ||
|
|
e644fdf38a | ||
|
|
ac01cfa5cc | ||
|
|
c22416cbb5 | ||
|
|
7d77d6cfb2 | ||
|
|
c9cdc8f3c7 | ||
|
|
3f33ea0f46 | ||
|
|
5916e1c89e | ||
|
|
5954a0563c | ||
|
|
c73b805719 | ||
|
|
0d787c7c80 | ||
|
|
6f05a8b934 |
@@ -237,7 +237,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -46,7 +46,6 @@ def compile_stableLM(
|
||||
model_vmfb_name,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug=False,
|
||||
):
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
@@ -93,7 +92,7 @@ def compile_stableLM(
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
|
||||
@@ -39,7 +39,6 @@ from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna7B,
|
||||
SecondVicuna13B,
|
||||
SecondVicuna70B,
|
||||
)
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
@@ -192,6 +191,7 @@ class VicunaBase(SharkLLMBase):
|
||||
first_vicuna_mlir,
|
||||
second_vicuna_mlir,
|
||||
output_name,
|
||||
model_name=None,
|
||||
):
|
||||
print(f"[DEBUG] combining first and second mlir")
|
||||
print(f"[DEBUG] output_name = {output_name}")
|
||||
@@ -355,7 +355,8 @@ class VicunaBase(SharkLLMBase):
|
||||
f_.writelines(line + "\n" for line in global_vars)
|
||||
f_.writelines(line + "\n" for line in f1)
|
||||
f_.writelines(line + "\n" for line in f2)
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
if not (model_name and "llama2_13b" in model_name):
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
|
||||
del maps1
|
||||
del maps2
|
||||
@@ -441,7 +442,6 @@ class ShardedVicuna(VicunaBase):
|
||||
weight_group_size=128,
|
||||
compressed=False,
|
||||
extra_args_cmd=[],
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model_name,
|
||||
@@ -452,7 +452,6 @@ class ShardedVicuna(VicunaBase):
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.config = config_json
|
||||
self.weight_group_size = weight_group_size
|
||||
@@ -640,7 +639,7 @@ class ShardedVicuna(VicunaBase):
|
||||
return device_idx
|
||||
|
||||
def compile_lmhead(
|
||||
self, lmh, hidden_states, device="cpu", device_idx=None,
|
||||
self, lmh, hidden_states, device="cpu", device_idx=None
|
||||
):
|
||||
# compile the lm head of the vicuna model
|
||||
# This can be used for both first and second vicuna, so only needs to be run once
|
||||
@@ -688,7 +687,7 @@ class ShardedVicuna(VicunaBase):
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="lmhead", debug=self.debug)
|
||||
shark_module.save_module(module_name="lmhead")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = LMHeadCompiled(shark_module)
|
||||
return compiled_module
|
||||
@@ -734,7 +733,7 @@ class ShardedVicuna(VicunaBase):
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="norm", debug=self.debug)
|
||||
shark_module.save_module(module_name="norm")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = VicunaNormCompiled(shark_module)
|
||||
return compiled_module
|
||||
@@ -785,14 +784,14 @@ class ShardedVicuna(VicunaBase):
|
||||
if vmfb_path.exists():
|
||||
shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
shark_module.save_module(module_name="embedding", debug=self.debug)
|
||||
shark_module.save_module(module_name="embedding")
|
||||
shark_module.load_module(vmfb_path)
|
||||
compiled_module = VicunaEmbeddingCompiled(shark_module)
|
||||
|
||||
return compiled_module
|
||||
|
||||
def compile_to_vmfb_one_model(
|
||||
self, inputs0, layers0, inputs1, layers1, device="cpu",
|
||||
self, inputs0, layers0, inputs1, layers1, device="cpu"
|
||||
):
|
||||
mlirs, modules = [], []
|
||||
assert len(layers0) == len(layers1)
|
||||
@@ -802,6 +801,7 @@ class ShardedVicuna(VicunaBase):
|
||||
# if vmfb_path.exists():
|
||||
# continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
@@ -855,7 +855,7 @@ class ShardedVicuna(VicunaBase):
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module0,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
@@ -899,7 +899,7 @@ class ShardedVicuna(VicunaBase):
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module1,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
@@ -924,6 +924,7 @@ class ShardedVicuna(VicunaBase):
|
||||
mlirs.append(module_combined)
|
||||
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"first_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
@@ -955,7 +956,6 @@ class ShardedVicuna(VicunaBase):
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
+ self.extra_args,
|
||||
debug=self.debug,
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
@@ -972,6 +972,7 @@ class ShardedVicuna(VicunaBase):
|
||||
# if vmfb_path.exists():
|
||||
# continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
@@ -990,6 +991,7 @@ class ShardedVicuna(VicunaBase):
|
||||
mlirs.append(bytecode)
|
||||
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"first_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
@@ -1021,7 +1023,6 @@ class ShardedVicuna(VicunaBase):
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
+ self.extra_args,
|
||||
debug=self.debug,
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
@@ -1230,13 +1231,12 @@ class UnshardedVicuna(VicunaBase):
|
||||
precision="int8",
|
||||
vicuna_mlir_path=None,
|
||||
vicuna_vmfb_path=None,
|
||||
load_mlir_from_shark_tank=False,
|
||||
load_mlir_from_shark_tank=True,
|
||||
low_device_memory=False,
|
||||
weight_group_size=128,
|
||||
download_vmfb=False,
|
||||
cache_vicunas=False,
|
||||
extra_args_cmd=[],
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model_name,
|
||||
@@ -1265,7 +1265,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
|
||||
self.low_device_memory = low_device_memory
|
||||
self.weight_group_size = weight_group_size
|
||||
self.debug = debug
|
||||
if self.vicuna_mlir_path == None:
|
||||
self.vicuna_mlir_path = self.get_model_path()
|
||||
if self.vicuna_vmfb_path == None:
|
||||
@@ -1276,7 +1275,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
|
||||
def get_model_path(self, suffix="mlir"):
|
||||
safe_device = self.device.split("-")[0]
|
||||
if suffix in ["mlirbc", "mlir"]:
|
||||
if suffix == "mlir":
|
||||
return Path(f"{self.model_name}_{self.precision}.{suffix}")
|
||||
return Path(
|
||||
f"{self.model_name}_{self.precision}_{safe_device}.{suffix}"
|
||||
@@ -1336,7 +1335,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
new_lines.append(line)
|
||||
return "\n".join(new_lines)
|
||||
|
||||
def write_in_dynamic_inputs1(self, module):
|
||||
def write_in_dynamic_inputs1(self, module, model_name):
|
||||
print("[DEBUG] writing dynamic inputs to second vicuna")
|
||||
|
||||
def remove_constant_dim(line):
|
||||
@@ -1364,12 +1363,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
|
||||
module = module.splitlines()
|
||||
new_lines = []
|
||||
|
||||
# Using a while loop and the pop method to avoid creating a copy of module
|
||||
if "llama2_13b" in self.model_name:
|
||||
if "llama2_13b" in model_name:
|
||||
pkv_tensor_shape = "tensor<1x40x?x128x"
|
||||
elif "llama2_70b" in self.model_name:
|
||||
pkv_tensor_shape = "tensor<1x8x?x128x"
|
||||
else:
|
||||
pkv_tensor_shape = "tensor<1x32x?x128x"
|
||||
if self.precision in ["fp16", "int4", "int8"]:
|
||||
@@ -1404,13 +1400,13 @@ class UnshardedVicuna(VicunaBase):
|
||||
|
||||
return "\n".join(new_lines)
|
||||
|
||||
def compile(self):
|
||||
def compile(self, download_vmfb=False):
|
||||
# Testing : DO NOT Download Vmfbs if not found. Modify later
|
||||
# download vmfbs for A100
|
||||
if not self.vicuna_vmfb_path.exists() and self.download_vmfb:
|
||||
print(
|
||||
f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
|
||||
)
|
||||
print(
|
||||
f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}"
|
||||
)
|
||||
if not self.vicuna_vmfb_path.exists() and download_vmfb:
|
||||
download_public_file(
|
||||
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
|
||||
self.vicuna_vmfb_path.absolute(),
|
||||
@@ -1423,237 +1419,233 @@ class UnshardedVicuna(VicunaBase):
|
||||
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
|
||||
return
|
||||
|
||||
print(f"[DEBUG] vmfb not found")
|
||||
mlir_generated = False
|
||||
for suffix in ["mlirbc", "mlir"]:
|
||||
self.vicuna_mlir_path = self.get_model_path(suffix)
|
||||
if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank:
|
||||
print(
|
||||
f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}"
|
||||
)
|
||||
download_public_file(
|
||||
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
|
||||
self.vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.vicuna_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
|
||||
with open(self.vicuna_mlir_path, "rb") as f:
|
||||
combined_module = f.read()
|
||||
mlir_generated = True
|
||||
break
|
||||
|
||||
if not mlir_generated:
|
||||
print(f"[DEBUG] mlir not found")
|
||||
print("[DEBUG] generating mlir on device")
|
||||
# Select a compilation prompt such that the resulting input_ids
|
||||
# from the model's tokenizer has shape [1, 19]
|
||||
if self.model_name == "codegen":
|
||||
compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
|
||||
else:
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
|
||||
first_model_path = f"first_{self.model_name}_{self.precision}.mlir"
|
||||
if Path(first_model_path).exists():
|
||||
print(f"loading {first_model_path}")
|
||||
with open(Path(first_model_path), "r") as f:
|
||||
first_module = f.read()
|
||||
else:
|
||||
# generate first vicuna
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=is_f16,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[
|
||||
0
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
first_module = None
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del firstVicunaCompileInput
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
"[DEBUG] successfully generated first vicuna linalg mlir"
|
||||
)
|
||||
first_module = self.write_in_dynamic_inputs0(
|
||||
str(first_module), dynamic_input_size=19
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(first_model_path, "w+") as f:
|
||||
f.write(first_module)
|
||||
print("Finished writing IR after dynamic")
|
||||
|
||||
print(f"[DEBUG] Starting generation of second llama")
|
||||
second_model_path = f"second_{self.model_name}_{self.precision}.mlir"
|
||||
if Path(second_model_path).exists():
|
||||
print(f"loading {second_model_path}")
|
||||
with open(Path(second_model_path), "r") as f:
|
||||
second_module = f.read()
|
||||
else:
|
||||
# generate second vicuna
|
||||
compilation_input_ids = torch.zeros(
|
||||
[1, 1], dtype=torch.int64
|
||||
)
|
||||
if self.model_name == "llama2_13b":
|
||||
dim1 = 40
|
||||
total_tuple = 80
|
||||
elif self.model_name == "llama2_70b":
|
||||
dim1 = 8
|
||||
total_tuple = 160
|
||||
else:
|
||||
dim1 = 32
|
||||
total_tuple = 64
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
|
||||
for _ in range(total_tuple)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
elif self.model_name == "llama2_70b":
|
||||
model = SecondVicuna70B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
is_f16=is_f16,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False] + [True] * total_tuple,
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
f16_input_mask=[False] + [True] * total_tuple,
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
second_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del secondVicunaCompileInput
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
"[DEBUG] successfully generated second vicuna linalg mlir"
|
||||
)
|
||||
second_module = self.write_in_dynamic_inputs1(
|
||||
str(second_module)
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(second_model_path, "w+") as f:
|
||||
f.write(second_module)
|
||||
print("Finished writing IR after dynamic")
|
||||
|
||||
combined_module = self.combine_mlir_scripts(
|
||||
first_module,
|
||||
second_module,
|
||||
self.vicuna_mlir_path,
|
||||
print(f"[DEBUG] vmfb not found at {self.vicuna_vmfb_path.absolute()}")
|
||||
if self.vicuna_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
|
||||
with open(self.vicuna_mlir_path, "rb") as f:
|
||||
combined_module = f.read()
|
||||
else:
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.vicuna_mlir_path.absolute()}"
|
||||
)
|
||||
del first_module, second_module
|
||||
mlir_generated = False
|
||||
if self.load_mlir_from_shark_tank:
|
||||
# download MLIR from shark tank
|
||||
for suffix in ["mlir", "mlirbc"]:
|
||||
self.vicuna_mlir_path = self.get_model_path(suffix)
|
||||
download_public_file(
|
||||
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
|
||||
self.vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.vicuna_mlir_path.exists():
|
||||
with open(self.vicuna_mlir_path, "rb") as f:
|
||||
combined_module = f.read()
|
||||
mlir_generated = True
|
||||
break
|
||||
self.vicuna_mlir_path = self.get_model_path("mlir")
|
||||
if not mlir_generated:
|
||||
print(
|
||||
f"[DEBUG] failed to download {self.vicuna_mlir_path.name} from shark tank"
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
print("[DEBUG] generating mlir on device")
|
||||
# Select a compilation prompt such that the resulting input_ids
|
||||
# from the model's tokenizer has shape [1, 19]
|
||||
if self.model_name == "codegen":
|
||||
compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
|
||||
else:
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
|
||||
if Path(f"first_{self.precision}.mlir").exists():
|
||||
print(f"loading first_{self.precision}.mlir")
|
||||
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
|
||||
first_module = f.read()
|
||||
else:
|
||||
# generate first vicuna
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=is_f16,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[
|
||||
0
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
first_module = None
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del firstVicunaCompileInput
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
"[DEBUG] successfully generated first vicuna linalg mlir"
|
||||
)
|
||||
first_module = self.write_in_dynamic_inputs0(
|
||||
str(first_module), dynamic_input_size=19
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(f"first_{self.precision}.mlir", "w+") as f:
|
||||
f.write(first_module)
|
||||
print("Finished writing IR after dynamic")
|
||||
|
||||
if Path(f"second_{self.precision}.mlir").exists():
|
||||
print(f"loading second_{self.precision}.mlir")
|
||||
with open(Path(f"second_{self.precision}.mlir"), "r") as f:
|
||||
second_module = f.read()
|
||||
else:
|
||||
# generate second vicuna
|
||||
compilation_input_ids = torch.zeros(
|
||||
[1, 1], dtype=torch.int64
|
||||
)
|
||||
if self.model_name == "llama2_13b":
|
||||
dim1 = 40
|
||||
total_tuple = 80
|
||||
else:
|
||||
dim1 = 32
|
||||
total_tuple = 64
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, dim1, 19, 128], dtype=torch.float32))
|
||||
for _ in range(total_tuple)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
if self.model_name == "llama2_13b":
|
||||
model = SecondVicuna13B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
else:
|
||||
model = SecondVicuna7B(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
is_f16 = self.precision in ["fp16", "int4"]
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
is_f16=is_f16,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False] + [True] * total_tuple,
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
if self.precision in ["fp16", "int4"]:
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
f16_input_mask=[False] + [True] * total_tuple,
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[
|
||||
i
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
run_pipeline_with_repro_report(
|
||||
second_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del secondVicunaCompileInput
|
||||
gc.collect()
|
||||
print(
|
||||
"[DEBUG] successfully generated second vicuna linalg mlir"
|
||||
)
|
||||
second_module = self.write_in_dynamic_inputs1(
|
||||
str(second_module)
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(f"second_{self.precision}.mlir", "w") as f:
|
||||
f.write(second_module)
|
||||
print("Finished writing IR after dynamic")
|
||||
|
||||
combined_module = self.combine_mlir_scripts(
|
||||
first_module,
|
||||
second_module,
|
||||
self.vicuna_mlir_path,
|
||||
self.model_name,
|
||||
)
|
||||
del first_module, second_module
|
||||
|
||||
print(self.device)
|
||||
if "rocm" in self.device:
|
||||
@@ -1672,7 +1664,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
+ self.extra_args,
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved vic vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
@@ -1740,6 +1731,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
yield detok, ""
|
||||
|
||||
res_str = self.decode_tokens(res_tokens)
|
||||
# print(f"[DEBUG] final output : \n{res_str}")
|
||||
yield res_str, "formatted"
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
datas = []
|
||||
datas += collect_data_files('torch')
|
||||
datas += copy_metadata('torch')
|
||||
datas += copy_metadata('tqdm')
|
||||
datas += copy_metadata('regex')
|
||||
datas += copy_metadata('requests')
|
||||
datas += copy_metadata('packaging')
|
||||
datas += copy_metadata('filelock')
|
||||
datas += copy_metadata('numpy')
|
||||
datas += copy_metadata('tokenizers')
|
||||
datas += copy_metadata('importlib_metadata')
|
||||
datas += copy_metadata('torch-mlir')
|
||||
datas += copy_metadata('omegaconf')
|
||||
datas += copy_metadata('safetensors')
|
||||
datas += copy_metadata('huggingface-hub')
|
||||
datas += copy_metadata('sentencepiece')
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
datas += collect_data_files('diffusers')
|
||||
datas += collect_data_files('transformers')
|
||||
datas += collect_data_files('opencv-python')
|
||||
datas += collect_data_files('pytorch_lightning')
|
||||
datas += collect_data_files('skimage')
|
||||
datas += collect_data_files('gradio')
|
||||
datas += collect_data_files('gradio_client')
|
||||
datas += collect_data_files('iree')
|
||||
datas += collect_data_files('google-cloud-storage')
|
||||
datas += collect_data_files('py-cpuinfo')
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tqdm")
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
|
||||
binaries = []
|
||||
|
||||
block_cipher = None
|
||||
|
||||
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
|
||||
a = Analysis(
|
||||
['scripts/vicuna.py'],
|
||||
pathex=['.'],
|
||||
binaries=binaries,
|
||||
datas=datas,
|
||||
hiddenimports=hiddenimports,
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
[],
|
||||
name='shark_llama_cli',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
@@ -18,7 +18,6 @@ class FirstVicuna(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
@@ -67,7 +66,6 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
@@ -157,6 +155,8 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
@@ -301,7 +301,7 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="int8",
|
||||
precision="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
@@ -418,6 +418,8 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
i79,
|
||||
i80,
|
||||
):
|
||||
# input_ids = input_tuple[0]
|
||||
# input_tuple = torch.unbind(pkv, dim=0)
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
@@ -590,540 +592,6 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondVicuna70B(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_path,
|
||||
precision="fp32",
|
||||
weight_group_size=128,
|
||||
model_name="vicuna",
|
||||
hf_auth_token: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
kwargs = {"torch_dtype": torch.float32}
|
||||
if "llama2" in model_name:
|
||||
kwargs["use_auth_token"] = hf_auth_token
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, low_cpu_mem_usage=True, **kwargs
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
|
||||
print("Second Vicuna applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
get_model_impl(self.model).layers,
|
||||
dtype=torch.float16,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
i0,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
i65,
|
||||
i66,
|
||||
i67,
|
||||
i68,
|
||||
i69,
|
||||
i70,
|
||||
i71,
|
||||
i72,
|
||||
i73,
|
||||
i74,
|
||||
i75,
|
||||
i76,
|
||||
i77,
|
||||
i78,
|
||||
i79,
|
||||
i80,
|
||||
i81,
|
||||
i82,
|
||||
i83,
|
||||
i84,
|
||||
i85,
|
||||
i86,
|
||||
i87,
|
||||
i88,
|
||||
i89,
|
||||
i90,
|
||||
i91,
|
||||
i92,
|
||||
i93,
|
||||
i94,
|
||||
i95,
|
||||
i96,
|
||||
i97,
|
||||
i98,
|
||||
i99,
|
||||
i100,
|
||||
i101,
|
||||
i102,
|
||||
i103,
|
||||
i104,
|
||||
i105,
|
||||
i106,
|
||||
i107,
|
||||
i108,
|
||||
i109,
|
||||
i110,
|
||||
i111,
|
||||
i112,
|
||||
i113,
|
||||
i114,
|
||||
i115,
|
||||
i116,
|
||||
i117,
|
||||
i118,
|
||||
i119,
|
||||
i120,
|
||||
i121,
|
||||
i122,
|
||||
i123,
|
||||
i124,
|
||||
i125,
|
||||
i126,
|
||||
i127,
|
||||
i128,
|
||||
i129,
|
||||
i130,
|
||||
i131,
|
||||
i132,
|
||||
i133,
|
||||
i134,
|
||||
i135,
|
||||
i136,
|
||||
i137,
|
||||
i138,
|
||||
i139,
|
||||
i140,
|
||||
i141,
|
||||
i142,
|
||||
i143,
|
||||
i144,
|
||||
i145,
|
||||
i146,
|
||||
i147,
|
||||
i148,
|
||||
i149,
|
||||
i150,
|
||||
i151,
|
||||
i152,
|
||||
i153,
|
||||
i154,
|
||||
i155,
|
||||
i156,
|
||||
i157,
|
||||
i158,
|
||||
i159,
|
||||
i160,
|
||||
):
|
||||
token = i0
|
||||
past_key_values = (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
(
|
||||
i65,
|
||||
i66,
|
||||
),
|
||||
(
|
||||
i67,
|
||||
i68,
|
||||
),
|
||||
(
|
||||
i69,
|
||||
i70,
|
||||
),
|
||||
(
|
||||
i71,
|
||||
i72,
|
||||
),
|
||||
(
|
||||
i73,
|
||||
i74,
|
||||
),
|
||||
(
|
||||
i75,
|
||||
i76,
|
||||
),
|
||||
(
|
||||
i77,
|
||||
i78,
|
||||
),
|
||||
(
|
||||
i79,
|
||||
i80,
|
||||
),
|
||||
(
|
||||
i81,
|
||||
i82,
|
||||
),
|
||||
(
|
||||
i83,
|
||||
i84,
|
||||
),
|
||||
(
|
||||
i85,
|
||||
i86,
|
||||
),
|
||||
(
|
||||
i87,
|
||||
i88,
|
||||
),
|
||||
(
|
||||
i89,
|
||||
i90,
|
||||
),
|
||||
(
|
||||
i91,
|
||||
i92,
|
||||
),
|
||||
(
|
||||
i93,
|
||||
i94,
|
||||
),
|
||||
(
|
||||
i95,
|
||||
i96,
|
||||
),
|
||||
(
|
||||
i97,
|
||||
i98,
|
||||
),
|
||||
(
|
||||
i99,
|
||||
i100,
|
||||
),
|
||||
(
|
||||
i101,
|
||||
i102,
|
||||
),
|
||||
(
|
||||
i103,
|
||||
i104,
|
||||
),
|
||||
(
|
||||
i105,
|
||||
i106,
|
||||
),
|
||||
(
|
||||
i107,
|
||||
i108,
|
||||
),
|
||||
(
|
||||
i109,
|
||||
i110,
|
||||
),
|
||||
(
|
||||
i111,
|
||||
i112,
|
||||
),
|
||||
(
|
||||
i113,
|
||||
i114,
|
||||
),
|
||||
(
|
||||
i115,
|
||||
i116,
|
||||
),
|
||||
(
|
||||
i117,
|
||||
i118,
|
||||
),
|
||||
(
|
||||
i119,
|
||||
i120,
|
||||
),
|
||||
(
|
||||
i121,
|
||||
i122,
|
||||
),
|
||||
(
|
||||
i123,
|
||||
i124,
|
||||
),
|
||||
(
|
||||
i125,
|
||||
i126,
|
||||
),
|
||||
(
|
||||
i127,
|
||||
i128,
|
||||
),
|
||||
(
|
||||
i129,
|
||||
i130,
|
||||
),
|
||||
(
|
||||
i131,
|
||||
i132,
|
||||
),
|
||||
(
|
||||
i133,
|
||||
i134,
|
||||
),
|
||||
(
|
||||
i135,
|
||||
i136,
|
||||
),
|
||||
(
|
||||
i137,
|
||||
i138,
|
||||
),
|
||||
(
|
||||
i139,
|
||||
i140,
|
||||
),
|
||||
(
|
||||
i141,
|
||||
i142,
|
||||
),
|
||||
(
|
||||
i143,
|
||||
i144,
|
||||
),
|
||||
(
|
||||
i145,
|
||||
i146,
|
||||
),
|
||||
(
|
||||
i147,
|
||||
i148,
|
||||
),
|
||||
(
|
||||
i149,
|
||||
i150,
|
||||
),
|
||||
(
|
||||
i151,
|
||||
i152,
|
||||
),
|
||||
(
|
||||
i153,
|
||||
i154,
|
||||
),
|
||||
(
|
||||
i155,
|
||||
i156,
|
||||
),
|
||||
(
|
||||
i157,
|
||||
i158,
|
||||
),
|
||||
(
|
||||
i159,
|
||||
i160,
|
||||
),
|
||||
)
|
||||
op = self.model(
|
||||
input_ids=token, use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(op.logits)
|
||||
temp_past_key_values = op.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class CombinedModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -3,10 +3,7 @@ from abc import ABC, abstractmethod
|
||||
|
||||
class SharkLLMBase(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path=None,
|
||||
max_num_tokens=512,
|
||||
self, model_name, hf_model_path=None, max_num_tokens=512
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.hf_model_path = hf_model_path
|
||||
|
||||
@@ -71,7 +71,6 @@ class Falcon(SharkLLMBase):
|
||||
precision="fp32",
|
||||
falcon_mlir_path=None,
|
||||
falcon_vmfb_path=None,
|
||||
debug=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_padding_length = 100
|
||||
@@ -79,7 +78,6 @@ class Falcon(SharkLLMBase):
|
||||
self.precision = precision
|
||||
self.falcon_vmfb_path = falcon_vmfb_path
|
||||
self.falcon_mlir_path = falcon_mlir_path
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
self.src_model = self.get_src_model()
|
||||
@@ -210,7 +208,6 @@ class Falcon(SharkLLMBase):
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
"--iree-spirv-index-bits=64",
|
||||
],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
@@ -178,7 +178,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
|
||||
|
||||
def compile_module(
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
|
||||
shark_module, extended_model_name, generate_vmfb, extra_args=[]
|
||||
):
|
||||
if generate_vmfb:
|
||||
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
|
||||
@@ -190,7 +190,7 @@ def compile_module(
|
||||
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), extended_model_name, extra_args, debug=debug
|
||||
os.getcwd(), extended_model_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
@@ -199,7 +199,7 @@ def compile_module(
|
||||
|
||||
|
||||
def compile_int_precision(
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
|
||||
model, inputs, precision, device, generate_vmfb, extended_model_name
|
||||
):
|
||||
torchscript_module = import_with_fx(
|
||||
model,
|
||||
@@ -219,7 +219,7 @@ def compile_int_precision(
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
@@ -251,7 +251,6 @@ def compile_int_precision(
|
||||
extended_model_name=extended_model_name,
|
||||
generate_vmfb=generate_vmfb,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
),
|
||||
bytecode,
|
||||
)
|
||||
@@ -295,7 +294,6 @@ def shark_compile_through_fx_int(
|
||||
device,
|
||||
generate_or_load_vmfb,
|
||||
extended_model_name,
|
||||
debug,
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
|
||||
@@ -32,13 +32,11 @@ class SharkStableLM(SharkLLMBase):
|
||||
max_num_tokens=512,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug="False",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.shark_model = self.compile()
|
||||
|
||||
@@ -113,7 +111,7 @@ class SharkStableLM(SharkLLMBase):
|
||||
shark_module.compile()
|
||||
|
||||
path = shark_module.save_module(
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
|
||||
vmfb_path.parent.absolute(), vmfb_path.stem
|
||||
)
|
||||
print("Saved vmfb at ", str(path))
|
||||
|
||||
|
||||
@@ -74,11 +74,8 @@ datas += [
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
blacklist = ["tests", "convert"]
|
||||
hiddenimports += [
|
||||
x
|
||||
for x in collect_submodules("transformers")
|
||||
if not any(kw in x for kw in blacklist)
|
||||
x for x in collect_submodules("transformers") if "tests" not in x
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime", "iree._runtime_libs"]
|
||||
|
||||
@@ -570,14 +570,6 @@ p.add_argument(
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--compile_debug",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag to toggle debug assert/verify flags for imported IR in the"
|
||||
"iree-compiler. Default to false.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
|
||||
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
|
||||
)
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
os.getcwd(), model_name, extra_args, debug=args.compile_debug
|
||||
os.getcwd(), model_name, extra_args
|
||||
)
|
||||
shark_module.load_module(path, extra_args=extra_args)
|
||||
else:
|
||||
|
||||
@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
|
||||
gr.Markdown(description)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Column(scale=0.5):
|
||||
image = gr.Image(type="pil")
|
||||
upload_button = gr.Button(
|
||||
value="Upload & Start Chat",
|
||||
|
||||
@@ -160,15 +160,14 @@ def chat(
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
global past_key_values
|
||||
global model_vmfb_key
|
||||
global vicuna_model
|
||||
|
||||
global vicuna_model
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
@@ -178,8 +177,6 @@ def chat(
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
@@ -197,6 +194,20 @@ def chat(
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
elif "rocm" in device:
|
||||
device = "rocm"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
@@ -228,8 +239,6 @@ def chat(
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=download_vmfb,
|
||||
load_mlir_from_shark_tank=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
# else:
|
||||
@@ -353,8 +362,6 @@ def llm_chat_api(InputData: dict):
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
download_vmfb=True,
|
||||
load_mlir_from_shark_tank=True,
|
||||
)
|
||||
|
||||
# TODO: add role dict for different models
|
||||
@@ -425,14 +432,15 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
device = gr.Dropdown(
|
||||
# print(supported_devices)
|
||||
devices = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
# multiselect=True,
|
||||
# multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
@@ -444,13 +452,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
@@ -485,15 +487,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[
|
||||
system_msg,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot, tokens_time],
|
||||
queue=True,
|
||||
)
|
||||
@@ -501,15 +495,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[
|
||||
system_msg,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot, tokens_time],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
@@ -7,13 +7,22 @@ import fileinput
|
||||
from pathlib import Path
|
||||
|
||||
# Temporary workaround for transformers/__init__.py.
|
||||
path_to_stdhooks = Path(
|
||||
get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks"
|
||||
)
|
||||
path_to_transformers_hook = Path(
|
||||
get_python_lib()
|
||||
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
|
||||
str(path_to_stdhooks) + "hook-transformers.py"
|
||||
)
|
||||
if path_to_transformers_hook.is_file():
|
||||
pass
|
||||
else:
|
||||
if not path_to_stdhooks.is_dir():
|
||||
import os
|
||||
|
||||
print(
|
||||
f"Path to pyinstaller stdhooks not found. Please check your pyinstaller packages at {path_to_stdhooks}."
|
||||
)
|
||||
os.mkdir(path_to_stdhooks)
|
||||
with open(path_to_transformers_hook, "w") as f:
|
||||
f.write("module_collection_mode = 'pyz+py'")
|
||||
|
||||
|
||||
@@ -92,27 +92,13 @@ def get_iree_frontend_args(frontend):
|
||||
|
||||
|
||||
# Common args to be used given any frontend or device.
|
||||
def get_iree_common_args(debug=False):
|
||||
common_args = [
|
||||
def get_iree_common_args():
|
||||
return [
|
||||
"--iree-stream-resource-max-allocation-size=4294967295",
|
||||
"--iree-vm-bytecode-module-strip-source-map=true",
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
"--iree-opt-strip-assertions=true",
|
||||
]
|
||||
if debug == True:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=false",
|
||||
"--verify=true",
|
||||
]
|
||||
)
|
||||
else:
|
||||
common_args.extend(
|
||||
[
|
||||
"--iree-opt-strip-assertions=true",
|
||||
"--verify=false",
|
||||
]
|
||||
)
|
||||
return common_args
|
||||
|
||||
|
||||
# Args that are suitable only for certain models or groups of models.
|
||||
@@ -291,13 +277,12 @@ def compile_module_to_flatbuffer(
|
||||
model_config_path,
|
||||
extra_args,
|
||||
model_name="None",
|
||||
debug=False,
|
||||
):
|
||||
# Setup Compile arguments wrt to frontends.
|
||||
input_type = ""
|
||||
args = get_iree_frontend_args(frontend)
|
||||
args += get_iree_device_args(device, extra_args)
|
||||
args += get_iree_common_args(debug=debug)
|
||||
args += get_iree_common_args()
|
||||
args += get_model_specific_args()
|
||||
args += extra_args
|
||||
|
||||
@@ -425,11 +410,10 @@ def get_iree_compiled_module(
|
||||
extra_args: list = [],
|
||||
device_idx: int = None,
|
||||
mmap: bool = False,
|
||||
debug: bool = False,
|
||||
):
|
||||
"""Given a module returns the compiled .vmfb and configs"""
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, frontend, model_config_path, extra_args, debug
|
||||
module, device, frontend, model_config_path, extra_args
|
||||
)
|
||||
temp_file_to_unlink = None
|
||||
# TODO: Currently mmap=True control flow path has been switched off for mmap.
|
||||
@@ -485,11 +469,10 @@ def export_iree_module_to_vmfb(
|
||||
model_config_path: str = None,
|
||||
module_name: str = None,
|
||||
extra_args: list = [],
|
||||
debug: bool = False,
|
||||
):
|
||||
# Compiles the module given specs and saves it as .vmfb file.
|
||||
flatbuffer_blob = compile_module_to_flatbuffer(
|
||||
module, device, mlir_dialect, model_config_path, extra_args, debug
|
||||
module, device, mlir_dialect, model_config_path, extra_args
|
||||
)
|
||||
if module_name is None:
|
||||
device_name = (
|
||||
|
||||
@@ -115,7 +115,7 @@ def compile_int_precision(
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
mlir_module,
|
||||
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
@@ -192,9 +192,7 @@ class SharkInference:
|
||||
|
||||
# TODO: Instead of passing directory and having names decided by the module
|
||||
# , user may want to save the module with manual names.
|
||||
def save_module(
|
||||
self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
|
||||
):
|
||||
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
|
||||
return export_iree_module_to_vmfb(
|
||||
self.mlir_module,
|
||||
self.device,
|
||||
@@ -202,7 +200,6 @@ class SharkInference:
|
||||
self.mlir_dialect,
|
||||
module_name=module_name,
|
||||
extra_args=extra_args,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
# load and return the module.
|
||||
|
||||
@@ -59,7 +59,7 @@ def create_module(model_name, tokenizer, device):
|
||||
)
|
||||
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
shark_module.save_module(module_name=vmfb_name, debug=False)
|
||||
shark_module.save_module(module_name=vmfb_name)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
|
||||
|
||||
Reference in New Issue
Block a user