Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00
Minor changes in the fx transforms.
@@ -317,16 +317,19 @@ def add_upcast(fx_g):
     import torch

     for node in fx_g.graph.nodes:
-        if node.target in [torch.ops.aten.rsqrt]:
+        if node.target in [torch.ops.aten.mul]:
+            # This is a very strict check.
             if (
-                node.args[0].target in [torch.ops.aten.add]
-                and node.args[0].args[0].target in [torch.ops.aten.mean]
-                and node.args[0].args[0].args[0].target in [torch.ops.aten.pow]
+                node.args[1].target in [torch.ops.aten.rsqrt]
+                and node.args[1].args[0].target in [torch.ops.aten.add]
+                and node.args[1].args[0].args[0].target
+                in [torch.ops.aten.mean]
+                and node.args[1].args[0].args[0].args[0].target
+                in [torch.ops.aten.pow]
             ):
                 print("found an upcasting block let's upcast it.")
-                pow_node = node.args[0].args[0].args[0]
-                rsqrt_node = node
+                pow_node = node.args[1].args[0].args[0].args[0]
+                mul_node = node
                 with fx_g.graph.inserting_before(pow_node):
                     lhs = pow_node.args[0]
                     upcast_lhs = fx_g.graph.call_function(
@@ -335,15 +338,15 @@ def add_upcast(fx_g):
                         kwargs={"dtype": torch.float32},
                     )
                     pow_node.args = (upcast_lhs, pow_node.args[1])
-                with fx_g.graph.inserting_before(rsqrt_node):
+                with fx_g.graph.inserting_before(mul_node):
                     new_node = fx_g.graph.call_function(
                         torch.ops.aten._to_copy,
-                        args=(rsqrt_node,),
+                        args=(mul_node,),
                         kwargs={"dtype": torch.float16},
                     )
-                    rsqrt_node.append(new_node)
-                    rsqrt_node.replace_all_uses_with(new_node)
-                    new_node.args = (rsqrt_node,)
+                    mul_node.append(new_node)
+                    mul_node.replace_all_uses_with(new_node)
+                    new_node.args = (mul_node,)
                     new_node.kwargs = {"dtype": torch.float16}

     fx_g.graph.lint()
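The reworked match anchors on the final aten.mul of an RMSNorm-style block, x * rsqrt(mean(pow(x, 2)) + eps), instead of the intermediate aten.rsqrt, so the float16 downcast via aten._to_copy now lands after the multiply and the float32 region covers the whole normalization. Below is a minimal sketch, not taken from the repository, of the insert-and-redirect idiom the hunk relies on; the Toy module and the simplified torch.mul predicate are illustrative only (symbolic_trace records torch.mul rather than the torch.ops.aten.* overloads the pass itself checks for).

import torch
import torch.fx as fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.mul(x, torch.rsqrt(x + 1.0))

fx_g = fx.symbolic_trace(Toy())

for node in fx_g.graph.nodes:
    if node.op == "call_function" and node.target is torch.mul:
        # Create a cast-back-to-fp16 node right after the mul.
        with fx_g.graph.inserting_after(node):
            cast = fx_g.graph.call_function(
                torch.ops.aten._to_copy,
                args=(node,),
                kwargs={"dtype": torch.float16},
            )
        # Point every consumer of the mul at the cast instead ...
        node.replace_all_uses_with(cast)
        # ... which also rewired the cast's own input, so restore it.
        cast.args = (node,)

fx_g.graph.lint()
fx_g.recompile()
print(fx_g.graph)

The trailing cast.args = (node,) mirrors the diff's new_node.args = (mul_node,): replace_all_uses_with rewrites every user of the original node, including the freshly inserted copy, which would otherwise consume itself.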
@@ -433,6 +436,14 @@ def transform_fx(fx_g):
                     node.replace_all_uses_with(new_node)
                     new_node.args = (node,)

+    # Required for cuda debugging.
+    # for node in fx_g.graph.nodes:
+    #     if node.op == "call_function":
+    #         if node.kwargs.get("device") == torch.device(type="cpu"):
+    #             new_kwargs = node.kwargs.copy()
+    #             new_kwargs["device"] = torch.device(type="cuda")
+    #             node.kwargs = new_kwargs
+
     fx_g.graph.lint()
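The block added at the bottom of transform_fx ships commented out; when enabled it walks every call_function node and, for nodes that carry an explicit device=cpu kwarg (tensor factory ops such as aten.empty or aten.arange), swaps the kwarg to cuda so the same graph can be exercised on a GPU. A standalone sketch of that rewrite follows; the helper name retarget_cpu_kwargs_to_cuda and the dict(...) copy are my own choices, not from the commit.

import torch
import torch.fx as fx

def retarget_cpu_kwargs_to_cuda(fx_g: fx.GraphModule) -> None:
    # Rewrite device="cpu" kwargs on call_function nodes to CUDA, in place.
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            if node.kwargs.get("device") == torch.device("cpu"):
                new_kwargs = dict(node.kwargs)  # node.kwargs is immutable
                new_kwargs["device"] = torch.device("cuda")
                node.kwargs = new_kwargs
    fx_g.graph.lint()
    fx_g.recompile()

Because node.kwargs is exposed as an immutable mapping, the copy-then-reassign pattern used in the commented code (and here) is the supported way to change a node's kwargs.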