[Bounty] moved index_tensor off cpu in torch_backend (#9916)

* moved index tensor off cpu in torch_backend

* added support for None based indexing

* fix_to_pass_tests

* fix segfault tests
This commit is contained in:
Nishant Rajadhyaksha
2025-04-24 11:12:37 -07:00
committed by GitHub
parent 373ca59b7f
commit 55942a8d8e

View File

@@ -100,10 +100,6 @@ def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).to(self.device)
return wrap(unwrap(self).assign(unwrap(ret)))
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y):
return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).to(x.device)
@torch.library.impl("aten::index_put", "privateuseone")
def index_put(self, indices, values, accumulate=False):
return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.cpu(), accumulate).tiny()
@@ -137,6 +133,10 @@ for i in [
# *** end bad functions on CPU ***
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y):
return wrap(unwrap(x)[[unwrap(_y.to(x.device)) if _y is not None else slice(None) for _y in y]])
@torch.library.impl("aten::zero_", "privateuseone")
@inplace_fn("x")
def zero_(x):