From 1df20fac95e42028a82d45fa3676cb95a6c058be Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 22 Sep 2022 18:02:09 -0400 Subject: [PATCH] [Lockstep] Hack to avoid aten._reshape_alias (#332) This enforces the decomposition for aten._reshape_alias used in AOTAutograd to essentially avoid having to deal with problems with strides when running in eager mode. --- shark/torch_mlir_lockstep_tensor.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/shark/torch_mlir_lockstep_tensor.py b/shark/torch_mlir_lockstep_tensor.py index a96b561e..bfcaafed 100644 --- a/shark/torch_mlir_lockstep_tensor.py +++ b/shark/torch_mlir_lockstep_tensor.py @@ -151,7 +151,12 @@ class TorchMLIRLockstepTensor(TorchMLIRTensor): with no_dispatch(): unwrapped_args = tree_map(cls.unwrap, args) unwrapped_kwargs = tree_map(cls.unwrap, kwargs) - native_out = func(*unwrapped_args, **unwrapped_kwargs) + if "_reshape_alias" in op_name: + native_out = torch.ops.aten.view( + unwrapped_args[0], unwrapped_args[1] + ) + else: + native_out = func(*unwrapped_args, **unwrapped_kwargs) native_out = tree_map( lambda x: cls(x, requires_grad=requires_grad), native_out @@ -195,7 +200,12 @@ class TorchMLIRLockstepTensor(TorchMLIRTensor): with no_dispatch(): unwrapped_args = tree_map(cls.unwrap, args) unwrapped_kwargs = tree_map(cls.unwrap, kwargs) - out = func(*unwrapped_args, **unwrapped_kwargs) + if "_reshape_alias" in op_name: + out = torch.ops.aten.view( + unwrapped_args[0], unwrapped_args[1] + ) + else: + out = func(*unwrapped_args, **unwrapped_kwargs) out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)