diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 508bcb39da..610a83f1fd 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -64,11 +64,6 @@ def inplace_fn(outvars: str|list[str]):
 
 # *** bad functions on CPU ***
 
-@torch.library.impl("aten::masked_select", "privateuseone")
-def masked_select(self, mask):
-  # err, bad
-  return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()], device=_from_torch_device(self.device)))
-
 @torch.library.impl("aten::_index_put_impl_", "privateuseone")
 @inplace_fn("self")
 def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
@@ -418,6 +413,7 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
   "aten.masked_fill_.Scalar": inplace_fn("self")(lambda self, mask, value: self.assign(self.masked_fill(mask, value))),
   "aten.masked_fill.Scalar": Tensor.masked_fill,
   "aten.masked_fill.Tensor": Tensor.masked_fill,
+  "aten.masked_select": Tensor.masked_select,
   "aten.all": Tensor.all,
   "aten.sgn": Tensor.sign,
   "aten.acos": Tensor.acos,
diff --git a/extra/torch_backend/test.py b/extra/torch_backend/test.py
index 39634ca577..49e9ccd03e 100644
--- a/extra/torch_backend/test.py
+++ b/extra/torch_backend/test.py
@@ -98,6 +98,15 @@ class TestTorchBackend(unittest.TestCase):
     np.testing.assert_equal(out.values.cpu().numpy(), [4, 3])
     np.testing.assert_equal(out.indices.cpu().numpy(), [3, 1])
 
+  def test_masked_select(self):
+    a = torch.tensor([4, 3, 2, 1], device=device)
+    mask = torch.tensor([True, False, True, False], device=device)
+    out = torch.masked_select(a, mask)
+    np.testing.assert_equal(out.cpu().numpy(), [4, 2])
+    mask = torch.tensor(True, device=device)
+    out = torch.masked_select(a, mask)
+    np.testing.assert_equal(out.cpu().numpy(), [4, 3, 2, 1])
+
   @unittest.skip("meh")
   def test_str(self):
     a = torch.ones(4, device=device)
diff --git a/test/test_ops.py b/test/test_ops.py
index 5fc60c0686..86483a7d44 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2883,6 +2883,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
     helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
 
+  def test_masked_select(self):
+    helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True)
+    helper_test_op([(32, 10)], lambda x: x.masked_select(torch.tensor(True)), lambda x: x.masked_select(Tensor(True)), forward_only=True)
+
   @unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
   def test_cast(self):
     helper_test_op([(3, 3)], lambda x: x.float())
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 5e979a6098..c6ff2d1b1e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1556,6 +1556,27 @@ class Tensor(SimpleMathTrait):
     t = t.permute([lhs.index(name) for name in rhs])
     return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
+
+  def masked_select(self, mask:Tensor) -> Tensor:
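
Below is a short usage sketch (not part of the patch) of the behavior the tests above check: a boolean tensor mask selects the matching elements into a flat 1-D result, and a 0-D True mask selects every element. It assumes a tinygrad build with this change applied, so `Tensor.masked_select` exists.

```python
# Minimal sketch of the masked_select semantics exercised by the tests
# in this patch; assumes tinygrad with this change applied.
from tinygrad import Tensor

a = Tensor([4, 3, 2, 1])
mask = Tensor([True, False, True, False])

# Like torch.masked_select: returns a flat 1-D tensor containing the
# elements of `a` where `mask` is True.
print(a.masked_select(mask).tolist())          # [4, 2]

# A 0-D boolean mask broadcasts against the whole tensor, so a scalar
# True selects every element (flattened).
print(a.masked_select(Tensor(True)).tolist())  # [4, 3, 2, 1]
```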