[Lockstep] Hack to avoid aten._reshape_alias (#332)

This enforces the decomposition of aten._reshape_alias, which AOTAutograd emits, into a plain aten.view, essentially to sidestep problems with strides when running in eager mode.
Author: Quinn Dawkins
Date: 2022-09-22 18:02:09 -04:00 (committed by GitHub)
Parent: 991e7043d1
Commit: 1df20fac95


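As the commit message notes, the change reroutes aten._reshape_alias to a plain aten.view at dispatch time. Below is a minimal sketch of that pattern, lifted out of the real TorchMLIRLockstepTensor and with the op name taken from str(func) rather than whatever helper the class actually uses to populate op_name:

    import torch

    def dispatch_without_reshape_alias(func, unwrapped_args, unwrapped_kwargs):
        # Sketch of the workaround only: reroute aten._reshape_alias to
        # aten.view so eager execution never has to reproduce the alias op's
        # explicit strides.
        op_name = str(func)  # e.g. "aten._reshape_alias.default" for an OpOverload
        if "_reshape_alias" in op_name:
            # _reshape_alias(self, size, stride): drop the stride argument and
            # return a plain view over the requested size.
            return torch.ops.aten.view(unwrapped_args[0], unwrapped_args[1])
        return func(*unwrapped_args, **unwrapped_kwargs)

In the actual class this check sits inside __torch_dispatch__, after the args/kwargs have been unwrapped with tree_map, as the two hunks below show.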
@@ -151,7 +151,12 @@ class TorchMLIRLockstepTensor(TorchMLIRTensor):
             with no_dispatch():
                 unwrapped_args = tree_map(cls.unwrap, args)
                 unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
-                native_out = func(*unwrapped_args, **unwrapped_kwargs)
+                if "_reshape_alias" in op_name:
+                    native_out = torch.ops.aten.view(
+                        unwrapped_args[0], unwrapped_args[1]
+                    )
+                else:
+                    native_out = func(*unwrapped_args, **unwrapped_kwargs)
             native_out = tree_map(
                 lambda x: cls(x, requires_grad=requires_grad), native_out
@@ -195,7 +200,12 @@ class TorchMLIRLockstepTensor(TorchMLIRTensor):
             with no_dispatch():
                 unwrapped_args = tree_map(cls.unwrap, args)
                 unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
-                out = func(*unwrapped_args, **unwrapped_kwargs)
+                if "_reshape_alias" in op_name:
+                    out = torch.ops.aten.view(
+                        unwrapped_args[0], unwrapped_args[1]
+                    )
+                else:
+                    out = func(*unwrapped_args, **unwrapped_kwargs)
             out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)
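
As a quick sanity check of the substitution (assuming a PyTorch build that exposes the internal aten._reshape_alias op from Python), the view call produces the same values and strides as the aliased reshape on a contiguous input:

    import torch

    x = torch.arange(12.0).reshape(3, 4)  # contiguous input, strides (4, 1)

    # _reshape_alias carries an explicit stride argument; (3, 1) is simply the
    # row-major stride for the (4, 3) target shape.
    alias_out = torch.ops.aten._reshape_alias(x, (4, 3), (3, 1))
    view_out = torch.ops.aten.view(x, (4, 3))

    assert torch.equal(alias_out, view_out)
    assert alias_out.stride() == view_out.stride() == (3, 1)

This only exercises the easy case; the commit title frames the substitution as a hack rather than a general rewrite of the op.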