add masked_select to tensor.py (#9468)

* add masked_select to tensor.py

* fix tests

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
TJ
2025-03-17 13:05:36 -07:00
committed by GitHub
parent 4f8eac59ea
commit 9fcef4d009
4 changed files with 35 additions and 5 deletions

View File

@@ -2883,6 +2883,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
def test_masked_select(self):
helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True)
helper_test_op([(32, 10)], lambda x: x.masked_select(torch.tensor(True)), lambda x: x.masked_select(Tensor(True)), forward_only=True)
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_cast(self):
helper_test_op([(3, 3)], lambda x: x.float())