Add support for Falcon GPTQ

This commit is contained in:
Vivek Khandelwal
2023-10-10 11:55:45 +00:00
parent a731eb6ed4
commit 0a618e1863
2 changed files with 66 additions and 4 deletions

View File

@@ -32,7 +32,7 @@ parser.add_argument(
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -235,7 +235,12 @@ class Falcon(SharkLLMBase):
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))

View File

@@ -500,12 +500,69 @@ def gptq_transforms(fx_g):
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node,),
kwargs={"dtype": torch.float32},
kwargs={"dtype": torch.float16},
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float32}
new_node.kwargs = {"dtype": torch.float16}
# Inputs and outputs of aten.mm should be upcasted to fp32.
if node.target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
new_node_arg1 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[1], torch.float32),
kwargs={},
)
node.args = (new_node_arg0, new_node_arg1)
if type(node.args[0]) == torch.fx.node.Node and node.args[
0
].target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
tmp = node.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node.args[0],),
kwargs={"dtype": torch.float16},
)
node.args[0].append(new_node)
node.args[0].replace_all_uses_with(new_node)
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
# Inputs and outputs of aten._softmax should be upcasted to fp32.
if node.target in [torch.ops.aten._softmax]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (new_node_arg0, node.args[1], node.args[2])
if (
type(node.args[0]) == torch.fx.node.Node
and node.args[0].target in [torch.ops.aten._softmax]
and node.target in [torch.ops.aten.expand]
):
with fx_g.graph.inserting_before(node):
tmp = node.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node.args[0],),
kwargs={"dtype": torch.float16},
)
node.args[0].append(new_node)
node.args[0].replace_all_uses_with(new_node)
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
fx_g.graph.lint()