Mirror of https://github.com/nod-ai/SHARK-Studio.git
Added a check so quantization is not performed when not necessary
@@ -535,6 +535,11 @@ class ShardedVicuna(VicunaBase):
         self.dir_path.mkdir(parents=True, exist_ok=True)
         self.shark_model = self.compile(device=device)
 
+    def check_all_artifacts_present(self):
+        file_list = [f"{i}_full" for i in range(self.n_layers_dict[self.model_name])] + ["norm", "embedding", "lmhead"]
+        file_exists_list = [Path(f"{self.dir_name}/{x}.vmfb").exists() or Path(f"{self.dir_name}/{x}.mlir").exists() for x in file_list]
+        return all(file_exists_list)
+
     def get_tokenizer(self):
         kwargs = {}
         if "llama2" in self.model_name:
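The new helper gates recompilation on whether every shard already has an artifact on disk: one file per transformer layer plus the norm, embedding, and lmhead shards, each counting as present if either its compiled .vmfb or its lowered .mlir exists. A minimal standalone sketch of the same check, with hypothetical dir_name and n_layers parameters standing in for the instance attributes used above:

from pathlib import Path

def check_all_artifacts_present(dir_name: str, n_layers: int) -> bool:
    # One artifact per transformer layer, plus the norm/embedding/lmhead shards.
    file_list = [f"{i}_full" for i in range(n_layers)] + ["norm", "embedding", "lmhead"]
    # A shard counts as present if either its compiled .vmfb or its .mlir exists.
    return all(
        Path(f"{dir_name}/{x}.vmfb").exists() or Path(f"{dir_name}/{x}.mlir").exists()
        for x in file_list
    )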
@@ -1321,34 +1326,40 @@ class ShardedVicuna(VicunaBase):
             )
 
         if self.precision in ["int4", "int8"]:
-            from brevitas_examples.common.generative.quantize import (
-                quantize_model,
-            )
-            from brevitas_examples.llm.llm_quant.run_utils import (
-                get_model_impl,
-            )
-
-            print("Applying weight quantization..")
-            weight_bit_width = 4 if self.precision == "int4" else 8
-            quantize_model(
-                get_model_impl(vicuna_model).layers,
-                dtype=torch.float32,
-                weight_quant_type="asym",
-                weight_bit_width=weight_bit_width,
-                weight_param_method="stats",
-                weight_scale_precision="float_scale",
-                weight_quant_granularity="per_group",
-                weight_group_size=self.weight_group_size,
-                quantize_weight_zero_point=False,
-                input_bit_width=None,
-                input_scale_type="float",
-                input_param_method="stats",
-                input_quant_type="asym",
-                input_quant_granularity="per_tensor",
-                quantize_input_zero_point=False,
-                seqlen=2048,
-            )
-            print("Weight quantization applied.")
+            if not self.check_all_artifacts_present():
+                print("Applying weight quantization..")
+                from brevitas_examples.common.generative.quantize import (
+                    quantize_model,
+                )
+                from brevitas_examples.llm.llm_quant.run_utils import (
+                    get_model_impl,
+                )
+                weight_bit_width = 4 if self.precision == "int4" else 8
+
+                quantize_model(
+                    get_model_impl(vicuna_model).layers,
+                    dtype=torch.float32,
+                    weight_quant_type="asym",
+                    weight_bit_width=weight_bit_width,
+                    weight_param_method="stats",
+                    weight_scale_precision="float_scale",
+                    weight_quant_granularity="per_group",
+                    weight_group_size=self.weight_group_size,
+                    quantize_weight_zero_point=False,
+                    input_bit_width=None,
+                    input_scale_type="float",
+                    input_param_method="stats",
+                    input_quant_type="asym",
+                    input_quant_granularity="per_tensor",
+                    quantize_input_zero_point=False,
+                    seqlen=2048,
+                )
+
+                print("Weight quantization applied.")
+
+            else:
+                print("Skipping quantization, as all required artifacts are present")
+
         placeholder_pkv_segment = tuple(
             (
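With the check in place, the Brevitas weight-quantization pass only runs when at least one shard still has to be compiled; once every .vmfb/.mlir artifact is cached, loading skips straight past it. A quick, purely illustrative way to see the gate flip, reusing the standalone helper sketched above (the real code calls the method on self):

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    n_layers = 2
    names = [f"{i}_full" for i in range(n_layers)] + ["norm", "embedding", "lmhead"]
    print(check_all_artifacts_present(d, n_layers))  # False: nothing cached, quantization would run
    for name in names:
        Path(d, f"{name}.vmfb").touch()              # stub in the compiled artifacts
    print(check_all_artifacts_present(d, n_layers))  # True: quantization is skipped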