diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index 56046e7cf0..bd39568b50 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -100,10 +100,6 @@ def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False): ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).to(self.device) return wrap(unwrap(self).assign(unwrap(ret))) -@torch.library.impl("aten::index.Tensor", "privateuseone") -def index_tensor(x, y): - return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).to(x.device) - @torch.library.impl("aten::index_put", "privateuseone") def index_put(self, indices, values, accumulate=False): return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.cpu(), accumulate).tiny() @@ -137,6 +133,10 @@ for i in [ # *** end bad functions on CPU *** +@torch.library.impl("aten::index.Tensor", "privateuseone") +def index_tensor(x, y): + return wrap(unwrap(x)[[unwrap(_y.to(x.device)) if _y is not None else slice(None) for _y in y]]) + @torch.library.impl("aten::zero_", "privateuseone") @inplace_fn("x") def zero_(x):