Fix Falcon GPTQ Pipeline

This commit is contained in:
Vivek Khandelwal
2023-10-11 12:28:06 +00:00
parent 0a618e1863
commit b83d32fafe
2 changed files with 15 additions and 23 deletions

View File

@@ -9,7 +9,7 @@ from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
@@ -118,11 +118,17 @@ class Falcon(SharkLLMBase):
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
"device_map": "cpu" if args.device == "cpu" else "cuda:0",
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile(self):
@@ -194,7 +200,7 @@ class Falcon(SharkLLMBase):
ts_graph = import_with_fx(
model,
falconCompileInput,
is_f16=self.precision == "fp16",
is_f16=self.precision in ["fp16", "int4"],
f16_input_mask=[False, False],
mlir_type="torchscript",
is_gptq=self.precision == "int4",
@@ -229,7 +235,7 @@ class Falcon(SharkLLMBase):
mlir_dialect="linalg",
)
path = shark_module.save_module(
self.falcon_vmfb_path,
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",
@@ -417,7 +423,7 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision == "fp16":
if self.precision in ["fp16", "int4"]:
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs

View File

@@ -491,23 +491,7 @@ def gptq_transforms(fx_g):
node.args[4],
)
# Downcasting the result of native_layer_norm back to fp16.
if node.name.startswith("getitem"):
with fx_g.graph.inserting_before(node):
if node.args[0].target in [
torch.ops.aten.native_layer_norm
]:
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node,),
kwargs={"dtype": torch.float16},
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float16}
# Inputs and outputs of aten.mm should be upcasted to fp32.
# Inputs of aten.mm should be upcasted to fp32.
if node.target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
@@ -522,6 +506,7 @@ def gptq_transforms(fx_g):
)
node.args = (new_node_arg0, new_node_arg1)
# Outputs of aten.mm should be downcasted to fp16.
if type(node.args[0]) == torch.fx.node.Node and node.args[
0
].target in [torch.ops.aten.mm]:
@@ -537,7 +522,7 @@ def gptq_transforms(fx_g):
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
# Inputs and outputs of aten._softmax should be upcasted to fp32.
# Inputs of aten._softmax should be upcasted to fp32.
if node.target in [torch.ops.aten._softmax]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
@@ -547,6 +532,7 @@ def gptq_transforms(fx_g):
)
node.args = (new_node_arg0, node.args[1], node.args[2])
# Outputs of aten._softmax should be downcasted to fp16.
if (
type(node.args[0]) == torch.fx.node.Node
and node.args[0].target in [torch.ops.aten._softmax]