enforce fp32 accumulates for cpu (#1873)

Author: Daniel Garvey
Date: 2023-10-06 11:34:49 -05:00
Committed by: GitHub
Parent: 7abddd01ec
Commit: 8bb364bcb8
2 changed files with 16 additions and 7 deletions


@@ -1504,6 +1504,7 @@ class UnshardedVicuna(VicunaBase):
 model = FirstVicuna(
     self.hf_model_path,
     self.precision,
+    "fp32" if self.device == "cpu" else "fp16",
     self.weight_group_size,
     self.model_name,
     self.hf_auth_token,
@@ -1600,6 +1601,7 @@ class UnshardedVicuna(VicunaBase):
 model = SecondVicuna13B(
     self.hf_model_path,
     self.precision,
+    "fp32" if self.device == "cpu" else "fp16",
     self.weight_group_size,
     self.model_name,
     self.hf_auth_token,
@@ -1608,6 +1610,7 @@ class UnshardedVicuna(VicunaBase):
 model = SecondVicuna70B(
     self.hf_model_path,
     self.precision,
+    "fp32" if self.device == "cpu" else "fp16",
     self.weight_group_size,
     self.model_name,
     self.hf_auth_token,
@@ -1616,6 +1619,7 @@ class UnshardedVicuna(VicunaBase):
 model = SecondVicuna7B(
     self.hf_model_path,
     self.precision,
+    "fp32" if self.device == "cpu" else "fp16",
     self.weight_group_size,
     self.model_name,
     self.hf_auth_token,
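Each of the four constructor calls above threads the same rule through as a new positional argument: CPU targets request fp32 accumulation, everything else keeps fp16. A minimal sketch of the rule in isolation (the helper name is ours, not the commit's):

```python
def accumulates_for(device: str) -> str:
    # Rule applied at each call site above: the CPU backend gets fp32
    # accumulates; GPU and other backends keep fp16.
    return "fp32" if device == "cpu" else "fp16"

assert accumulates_for("cpu") == "fp32"
assert accumulates_for("cuda") == "fp16"
```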


@@ -6,7 +6,7 @@ class FirstVicuna(torch.nn.Module):
 def __init__(
     self,
     model_path,
-    precision="fp32",
+    precision="fp32", accumulates="fp32",
     weight_group_size=128,
     model_name="vicuna",
     hf_auth_token: str = None,
@@ -15,6 +15,7 @@ class FirstVicuna(torch.nn.Module):
 kwargs = {"torch_dtype": torch.float32}
 if "llama2" in model_name:
     kwargs["use_auth_token"] = hf_auth_token
+self.accumulates = torch.float32 if accumulates == "fp32" else torch.float16
 self.model = AutoModelForCausalLM.from_pretrained(
     model_path, low_cpu_mem_usage=True, **kwargs
 )
@@ -29,7 +30,7 @@ class FirstVicuna(torch.nn.Module):
 weight_bit_width = 4 if precision == "int4" else 8
 quantize_model(
     get_model_impl(self.model).layers,
-    dtype=torch.float16 if precision == "int4" else torch.float32,
+    dtype=self.accumulates,
     weight_bit_width=weight_bit_width,
     weight_param_method="stats",
     weight_scale_precision="float",
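Before this change, FirstVicuna picked the quantization dtype from the weight precision (float16 for int4, float32 otherwise); now the device-derived accumulates string decides. A small self-contained truth table contrasting the two policies (plain Python, mirroring the two ternaries in the diff):

```python
import torch

# Old policy: dtype follows weight precision. New policy: dtype follows
# the requested accumulates, so int4 on CPU now accumulates in fp32.
for device in ("cpu", "cuda"):
    accumulates = "fp32" if device == "cpu" else "fp16"
    for precision in ("int4", "int8"):
        old = torch.float16 if precision == "int4" else torch.float32
        new = torch.float32 if accumulates == "fp32" else torch.float16
        print(f"{device:>4} {precision}: old={old} new={new}")
```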
@@ -57,7 +58,7 @@ class SecondVicuna7B(torch.nn.Module):
 def __init__(
     self,
     model_path,
-    precision="fp32",
+    precision="fp32", accumulates="fp32",
     weight_group_size=128,
     model_name="vicuna",
     hf_auth_token: str = None,
@@ -69,6 +70,7 @@ class SecondVicuna7B(torch.nn.Module):
 self.model = AutoModelForCausalLM.from_pretrained(
     model_path, low_cpu_mem_usage=True, **kwargs
 )
+self.accumulates = torch.float32 if accumulates == "fp32" else torch.float16
 print(f"[DEBUG] model_path : {model_path}")
 if precision in ["int4", "int8"]:
     from brevitas_examples.llm.llm_quant.quantize import quantize_model
@@ -80,7 +82,7 @@ class SecondVicuna7B(torch.nn.Module):
 weight_bit_width = 4 if precision == "int4" else 8
 quantize_model(
     get_model_impl(self.model).layers,
-    dtype=torch.float16 if precision == "int4" else torch.float32,
+    dtype=self.accumulates,
     weight_bit_width=weight_bit_width,
     weight_param_method="stats",
     weight_scale_precision="float",
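SecondVicuna7B gets the identical treatment, and the same string-to-dtype ternary now appears verbatim in all four model classes. A hypothetical consolidation (not part of the commit) would be one shared helper:

```python
import torch

def accumulate_dtype(accumulates: str) -> torch.dtype:
    # Hypothetical shared helper: the commit repeats this ternary in
    # FirstVicuna and the three SecondVicuna* classes. Anything other
    # than "fp32" falls back to fp16.
    return torch.float32 if accumulates == "fp32" else torch.float16
```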
@@ -305,6 +307,7 @@ class SecondVicuna13B(torch.nn.Module):
 self,
 model_path,
 precision="int8",
+accumulates="fp32",
 weight_group_size=128,
 model_name="vicuna",
 hf_auth_token: str = None,
@@ -316,6 +319,7 @@ class SecondVicuna13B(torch.nn.Module):
 self.model = AutoModelForCausalLM.from_pretrained(
     model_path, low_cpu_mem_usage=True, **kwargs
 )
+self.accumulates = torch.float32 if accumulates == "fp32" else torch.float16
 if precision in ["int4", "int8"]:
     from brevitas_examples.llm.llm_quant.quantize import quantize_model
     from brevitas_examples.llm.llm_quant.run_utils import (
@@ -326,7 +330,7 @@ class SecondVicuna13B(torch.nn.Module):
 weight_bit_width = 4 if precision == "int4" else 8
 quantize_model(
     get_model_impl(self.model).layers,
-    dtype=torch.float16 if precision == "int4" else torch.float32,
+    dtype=self.accumulates,
     weight_bit_width=weight_bit_width,
     weight_param_method="stats",
     weight_scale_precision="float",
@@ -597,7 +601,7 @@ class SecondVicuna70B(torch.nn.Module):
 def __init__(
     self,
     model_path,
-    precision="fp32",
+    precision="fp32", accumulates="fp32",
     weight_group_size=128,
     model_name="vicuna",
     hf_auth_token: str = None,
@@ -609,6 +613,7 @@ class SecondVicuna70B(torch.nn.Module):
 self.model = AutoModelForCausalLM.from_pretrained(
     model_path, low_cpu_mem_usage=True, **kwargs
 )
+self.accumulates = torch.float32 if accumulates == "fp32" else torch.float16
 print(f"[DEBUG] model_path : {model_path}")
 if precision in ["int4", "int8"]:
     from brevitas_examples.llm.llm_quant.quantize import quantize_model
@@ -620,7 +625,7 @@ class SecondVicuna70B(torch.nn.Module):
 weight_bit_width = 4 if precision == "int4" else 8
 quantize_model(
     get_model_impl(self.model).layers,
-    dtype=torch.float16,
+    dtype=self.accumulates,
     weight_bit_width=weight_bit_width,
     weight_param_method="stats",
     weight_scale_precision="float",
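Note that SecondVicuna70B previously hard-coded dtype=torch.float16 here regardless of precision; after this commit all four classes defer to the requested accumulate dtype. As motivation for forcing fp32 on CPU, a small self-contained illustration (ours, not the commit's) of fp16 accumulation drift: once the running total grows, a step smaller than half the fp16 spacing rounds away and the sum stalls.

```python
import torch

# Accumulate 10,000 steps of 1e-4 in fp16 vs. fp32 (assumes a PyTorch
# build with fp16 elementwise adds on CPU).
acc16 = torch.tensor(0.0, dtype=torch.float16)
acc32 = torch.tensor(0.0, dtype=torch.float32)
step = torch.tensor(1e-4, dtype=torch.float16)
for _ in range(10_000):
    acc16 = acc16 + step          # result rounded to fp16 every add
    acc32 = acc32 + step.float()  # accumulated at fp32 precision
print(acc16.item())  # stalls near 0.25 -- far below the true sum
print(acc32.item())  # ~1.0, as expected
```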