mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add support for Falcon GPTQ
This commit is contained in:
@@ -32,7 +32,7 @@ parser.add_argument(
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
@@ -235,7 +235,12 @@ class Falcon(SharkLLMBase):
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
],
|
||||
]
|
||||
+ [
|
||||
"--iree-llvmcpu-use-fast-min-max-ops",
|
||||
]
|
||||
if self.precision == "int4"
|
||||
else [],
|
||||
debug=self.debug,
|
||||
)
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
|
||||
@@ -500,12 +500,69 @@ def gptq_transforms(fx_g):
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node,),
|
||||
kwargs={"dtype": torch.float32},
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.append(new_node)
|
||||
node.replace_all_uses_with(new_node)
|
||||
new_node.args = (node,)
|
||||
new_node.kwargs = {"dtype": torch.float32}
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# Inputs and outputs of aten.mm should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten.mm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
new_node_arg1 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[1], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node_arg0, new_node_arg1)
|
||||
|
||||
if type(node.args[0]) == torch.fx.node.Node and node.args[
|
||||
0
|
||||
].target in [torch.ops.aten.mm]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
tmp = node.args[0]
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node.args[0],),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.args[0].append(new_node)
|
||||
node.args[0].replace_all_uses_with(new_node)
|
||||
new_node.args = (tmp,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
# Inputs and outputs of aten._softmax should be upcasted to fp32.
|
||||
if node.target in [torch.ops.aten._softmax]:
|
||||
with fx_g.graph.inserting_before(node):
|
||||
new_node_arg0 = fx_g.graph.call_function(
|
||||
torch.ops.prims.convert_element_type,
|
||||
args=(node.args[0], torch.float32),
|
||||
kwargs={},
|
||||
)
|
||||
node.args = (new_node_arg0, node.args[1], node.args[2])
|
||||
|
||||
if (
|
||||
type(node.args[0]) == torch.fx.node.Node
|
||||
and node.args[0].target in [torch.ops.aten._softmax]
|
||||
and node.target in [torch.ops.aten.expand]
|
||||
):
|
||||
with fx_g.graph.inserting_before(node):
|
||||
tmp = node.args[0]
|
||||
new_node = fx_g.graph.call_function(
|
||||
torch.ops.aten._to_copy,
|
||||
args=(node.args[0],),
|
||||
kwargs={"dtype": torch.float16},
|
||||
)
|
||||
node.args[0].append(new_node)
|
||||
node.args[0].replace_all_uses_with(new_node)
|
||||
new_node.args = (tmp,)
|
||||
new_node.kwargs = {"dtype": torch.float16}
|
||||
|
||||
fx_g.graph.lint()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user