added check so quantization is not performed when not necessary

This commit is contained in:
Elias Joseph
2023-12-06 01:16:11 -08:00
parent 0c0329699a
commit 6cefb450bf

View File

@@ -535,6 +535,11 @@ class ShardedVicuna(VicunaBase):
self.dir_path.mkdir(parents=True, exist_ok=True)
self.shark_model = self.compile(device=device)
def check_all_artifacts_present(self):
file_list = [f"{i}_full" for i in range(self.n_layers_dict[self.model_name])] + ["norm", "embedding", "lmhead"]
file_exists_list = [Path(f"{self.dir_name}/{x}.vmfb").exists() or Path(f"{self.dir_name}/{x}.mlir").exists() for x in file_list]
return all(file_exists_list)
def get_tokenizer(self):
kwargs = {}
if "llama2" in self.model_name:
@@ -1321,34 +1326,40 @@ class ShardedVicuna(VicunaBase):
)
if self.precision in ["int4", "int8"]:
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Applying weight quantization..")
weight_bit_width = 4 if self.precision == "int4" else 8
quantize_model(
get_model_impl(vicuna_model).layers,
dtype=torch.float32,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float_scale",
weight_quant_granularity="per_group",
weight_group_size=self.weight_group_size,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")
if not self.check_all_artifacts_present():
print("Applying weight quantization..")
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
weight_bit_width = 4 if self.precision == "int4" else 8
quantize_model(
get_model_impl(vicuna_model).layers,
dtype=torch.float32,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float_scale",
weight_quant_granularity="per_group",
weight_group_size=self.weight_group_size,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")
else:
print("Skipping quantization, as all required artifacts are present")
placeholder_pkv_segment = tuple(
(