mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
[Bounty] moved index_tensor off cpu in torch_backend (#9916)
* moved index tensor off cpu in torch_backend * added support for None based indexing * fix_to_pass_tests * fix segfault tests
This commit is contained in:
committed by
GitHub
parent
373ca59b7f
commit
55942a8d8e
@@ -100,10 +100,6 @@ def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
|
||||
ret = aten._index_put_impl_(self.cpu(), [x.cpu() if isinstance(x, torch.Tensor) else None for x in indices], values.cpu(), accumulate, unsafe).to(self.device)
|
||||
return wrap(unwrap(self).assign(unwrap(ret)))
|
||||
|
||||
@torch.library.impl("aten::index.Tensor", "privateuseone")
|
||||
def index_tensor(x, y):
|
||||
return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).to(x.device)
|
||||
|
||||
@torch.library.impl("aten::index_put", "privateuseone")
|
||||
def index_put(self, indices, values, accumulate=False):
|
||||
return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.cpu(), accumulate).tiny()
|
||||
@@ -137,6 +133,10 @@ for i in [
|
||||
|
||||
# *** end bad functions on CPU ***
|
||||
|
||||
@torch.library.impl("aten::index.Tensor", "privateuseone")
|
||||
def index_tensor(x, y):
|
||||
return wrap(unwrap(x)[[unwrap(_y.to(x.device)) if _y is not None else slice(None) for _y in y]])
|
||||
|
||||
@torch.library.impl("aten::zero_", "privateuseone")
|
||||
@inplace_fn("x")
|
||||
def zero_(x):
|
||||
|
||||
Reference in New Issue
Block a user