mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add masked_select to tensor.py (#9468)
* add masked_select to tensor.py * fix tests --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -2883,6 +2883,10 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
|
||||
|
||||
def test_masked_select(self):
|
||||
helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True)
|
||||
helper_test_op([(32, 10)], lambda x: x.masked_select(torch.tensor(True)), lambda x: x.masked_select(Tensor(True)), forward_only=True)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
|
||||
def test_cast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.float())
|
||||
|
||||
Reference in New Issue
Block a user