diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index de2ec698..a661be24 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -535,6 +535,11 @@ class ShardedVicuna(VicunaBase):
         self.dir_path.mkdir(parents=True, exist_ok=True)
         self.shark_model = self.compile(device=device)
 
+    def check_all_artifacts_present(self):
+        file_list = [f"{i}_full" for i in range(self.n_layers_dict[self.model_name])] + ["norm", "embedding", "lmhead"]
+        file_exists_list = [Path(f"{self.dir_name}/{x}.vmfb").exists() or Path(f"{self.dir_name}/{x}.mlir").exists() for x in file_list]
+        return all(file_exists_list)
+
     def get_tokenizer(self):
         kwargs = {}
         if "llama2" in self.model_name:
@@ -1321,34 +1326,40 @@ class ShardedVicuna(VicunaBase):
         )
 
         if self.precision in ["int4", "int8"]:
-            from brevitas_examples.common.generative.quantize import (
-                quantize_model,
-            )
-            from brevitas_examples.llm.llm_quant.run_utils import (
-                get_model_impl,
-            )
-            print("Applying weight quantization..")
-            weight_bit_width = 4 if self.precision == "int4" else 8
-            quantize_model(
-                get_model_impl(vicuna_model).layers,
-                dtype=torch.float32,
-                weight_quant_type="asym",
-                weight_bit_width=weight_bit_width,
-                weight_param_method="stats",
-                weight_scale_precision="float_scale",
-                weight_quant_granularity="per_group",
-                weight_group_size=self.weight_group_size,
-                quantize_weight_zero_point=False,
-                input_bit_width=None,
-                input_scale_type="float",
-                input_param_method="stats",
-                input_quant_type="asym",
-                input_quant_granularity="per_tensor",
-                quantize_input_zero_point=False,
-                seqlen=2048,
-            )
-            print("Weight quantization applied.")
+            if not self.check_all_artifacts_present():
+                print("Applying weight quantization..")
+                from brevitas_examples.common.generative.quantize import (
+                    quantize_model,
+                )
+                from brevitas_examples.llm.llm_quant.run_utils import (
+                    get_model_impl,
+                )
+                weight_bit_width = 4 if self.precision == "int4" else 8
+
+                quantize_model(
+                    get_model_impl(vicuna_model).layers,
+                    dtype=torch.float32,
+                    weight_quant_type="asym",
+                    weight_bit_width=weight_bit_width,
+                    weight_param_method="stats",
+                    weight_scale_precision="float_scale",
+                    weight_quant_granularity="per_group",
+                    weight_group_size=self.weight_group_size,
+                    quantize_weight_zero_point=False,
+                    input_bit_width=None,
+                    input_scale_type="float",
+                    input_param_method="stats",
+                    input_quant_type="asym",
+                    input_quant_granularity="per_tensor",
+                    quantize_input_zero_point=False,
+                    seqlen=2048,
+                )
+
+                print("Weight quantization applied.")
+
+            else:
+                print("Skipping quantization, as all required artifacts are present")
 
         placeholder_pkv_segment = tuple(
            (
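
Note (not part of the patch): the gating above skips Brevitas weight quantization whenever every shard artifact already exists on disk as either a compiled .vmfb or an exported .mlir. A minimal standalone sketch of that check follows; the layer-count table, directory name, and __main__ harness are hypothetical stand-ins for self.n_layers_dict, self.dir_name, and the surrounding class.

from pathlib import Path

# Hypothetical layer counts; the real values live in self.n_layers_dict.
N_LAYERS = {"vicuna": 32}

def check_all_artifacts_present(dir_name: str, model_name: str) -> bool:
    # One artifact per transformer layer, plus the norm, embedding,
    # and lm-head modules that are sharded separately.
    names = [f"{i}_full" for i in range(N_LAYERS[model_name])]
    names += ["norm", "embedding", "lmhead"]
    # A shard counts as present if either its compiled .vmfb or its
    # exported .mlir is on disk.
    return all(
        Path(f"{dir_name}/{n}.vmfb").exists()
        or Path(f"{dir_name}/{n}.mlir").exists()
        for n in names
    )

if __name__ == "__main__":
    if check_all_artifacts_present("vicuna_shards", "vicuna"):
        print("Skipping quantization, as all required artifacts are present")
    else:
        print("Applying weight quantization..")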