Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-04-03 03:00:17 -04:00.
Add fix for attention slicing fp16 (#1217)
This commit is contained in:
@@ -297,6 +297,7 @@ def transform_fx(fx_g):
            if node.target in [
                torch.ops.aten.arange,
                torch.ops.aten.empty,
                torch.ops.aten.zeros,
            ]:
                node.kwargs = kwargs_dict
            # Inputs and outputs of aten.var.mean should be upcasted to fp32.
Reference in New Issue
Block a user