Add fix for attention slicing fp16 (#1217)

This commit is contained in:
gpetters94
2023-03-20 22:11:29 -04:00
committed by GitHub
parent d105246b9c
commit 7899e1803a

View File

@@ -297,6 +297,7 @@ def transform_fx(fx_g):
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.zeros,
]:
node.kwargs = kwargs_dict
# Inputs and outputs of aten.var.mean should be upcasted to fp32.