From 8bb364bcb806163272fbba0061b79d295236bfcb Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Fri, 6 Oct 2023 11:34:49 -0500 Subject: [PATCH] enforce fp32 accumulates for cpu (#1873) --- apps/language_models/scripts/vicuna.py | 4 ++++ .../src/model_wrappers/vicuna_model.py | 19 ++++++++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 9cc0b509..8727f0f1 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -1504,6 +1504,7 @@ class UnshardedVicuna(VicunaBase): model = FirstVicuna( self.hf_model_path, self.precision, + "fp32" if self.device=="cpu" else "fp16", self.weight_group_size, self.model_name, self.hf_auth_token, @@ -1600,6 +1601,7 @@ class UnshardedVicuna(VicunaBase): model = SecondVicuna13B( self.hf_model_path, self.precision, + "fp32" if self.device=="cpu" else "fp16", self.weight_group_size, self.model_name, self.hf_auth_token, @@ -1608,6 +1610,7 @@ class UnshardedVicuna(VicunaBase): model = SecondVicuna70B( self.hf_model_path, self.precision, + "fp32" if self.device=="cpu" else "fp16", self.weight_group_size, self.model_name, self.hf_auth_token, @@ -1616,6 +1619,7 @@ class UnshardedVicuna(VicunaBase): model = SecondVicuna7B( self.hf_model_path, self.precision, + "fp32" if self.device=="cpu" else "fp16", self.weight_group_size, self.model_name, self.hf_auth_token, diff --git a/apps/language_models/src/model_wrappers/vicuna_model.py b/apps/language_models/src/model_wrappers/vicuna_model.py index fa06e468..b81d8388 100644 --- a/apps/language_models/src/model_wrappers/vicuna_model.py +++ b/apps/language_models/src/model_wrappers/vicuna_model.py @@ -6,7 +6,7 @@ class FirstVicuna(torch.nn.Module): def __init__( self, model_path, - precision="fp32", + precision="fp32",accumulates="fp32", weight_group_size=128, model_name="vicuna", hf_auth_token: str = None, @@ -15,6 +15,7 @@ class FirstVicuna(torch.nn.Module): kwargs = {"torch_dtype": torch.float32} if "llama2" in model_name: kwargs["use_auth_token"] = hf_auth_token + self.accumulates = torch.float32 if accumulates=="fp32" else torch.float16 self.model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) @@ -29,7 +30,7 @@ class FirstVicuna(torch.nn.Module): weight_bit_width = 4 if precision == "int4" else 8 quantize_model( get_model_impl(self.model).layers, - dtype=torch.float16 if precision == "int4" else torch.float32, + dtype=self.accumulates, weight_bit_width=weight_bit_width, weight_param_method="stats", weight_scale_precision="float", @@ -57,7 +58,7 @@ class SecondVicuna7B(torch.nn.Module): def __init__( self, model_path, - precision="fp32", + precision="fp32",accumulates="fp32", weight_group_size=128, model_name="vicuna", hf_auth_token: str = None, @@ -69,6 +70,7 @@ class SecondVicuna7B(torch.nn.Module): self.model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) + self.accumulates = torch.float32 if accumulates=="fp32" else torch.float16 print(f"[DEBUG] model_path : {model_path}") if precision in ["int4", "int8"]: from brevitas_examples.llm.llm_quant.quantize import quantize_model @@ -80,7 +82,7 @@ class SecondVicuna7B(torch.nn.Module): weight_bit_width = 4 if precision == "int4" else 8 quantize_model( get_model_impl(self.model).layers, - dtype=torch.float16 if precision == "int4" else torch.float32, + dtype=self.accumulates, weight_bit_width=weight_bit_width, weight_param_method="stats", weight_scale_precision="float", @@ -305,6 +307,7 @@ class SecondVicuna13B(torch.nn.Module): self, model_path, precision="int8", + accumulates="fp32", weight_group_size=128, model_name="vicuna", hf_auth_token: str = None, @@ -316,6 +319,7 @@ class SecondVicuna13B(torch.nn.Module): self.model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) + self.accumulates = torch.float32 if accumulates=="fp32" else torch.float16 if precision in ["int4", "int8"]: from brevitas_examples.llm.llm_quant.quantize import quantize_model from brevitas_examples.llm.llm_quant.run_utils import ( @@ -326,7 +330,7 @@ class SecondVicuna13B(torch.nn.Module): weight_bit_width = 4 if precision == "int4" else 8 quantize_model( get_model_impl(self.model).layers, - dtype=torch.float16 if precision == "int4" else torch.float32, + dtype=self.accumulates, weight_bit_width=weight_bit_width, weight_param_method="stats", weight_scale_precision="float", @@ -597,7 +601,7 @@ class SecondVicuna70B(torch.nn.Module): def __init__( self, model_path, - precision="fp32", + precision="fp32",accumulates="fp32", weight_group_size=128, model_name="vicuna", hf_auth_token: str = None, @@ -609,6 +613,7 @@ class SecondVicuna70B(torch.nn.Module): self.model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) + self.accumulates = torch.float32 if accumulates=="fp32" else torch.float16 print(f"[DEBUG] model_path : {model_path}") if precision in ["int4", "int8"]: from brevitas_examples.llm.llm_quant.quantize import quantize_model @@ -620,7 +625,7 @@ class SecondVicuna70B(torch.nn.Module): weight_bit_width = 4 if precision == "int4" else 8 quantize_model( get_model_impl(self.model).layers, - dtype=torch.float16, + dtype=self.accumulates, weight_bit_width=weight_bit_width, weight_param_method="stats", weight_scale_precision="float",