Tensor.where method (#830)

This commit is contained in:
Kirill R
2023-05-28 20:20:33 +03:00
committed by GitHub
parent eea3542975
commit 081b3ab639
2 changed files with 5 additions and 6 deletions

View File

@@ -66,13 +66,13 @@ class TestOps(unittest.TestCase):
helper_test_op(
[(100,)],
lambda x: torch.where(x > 0.5, 4, 2),
lambda x: Tensor.where(x > 0.5, 4, 2), forward_only=True)
lambda x: (x > 0.5).where(4, 2), forward_only=True)
for shps in [[(10,),(1,),(1,)], [(10,10),(10,),(10,)], [(100,)]*3, [(10,10)]*3]:
helper_test_op(
shps,
lambda x, a, b: torch.where(x > 0.5, a, b),
lambda x, a, b: Tensor.where(x > 0.5, a, b), forward_only=True)
lambda x, a, b: (x > 0.5).where(a, b), forward_only=True)
def _test_cmp(self, fxn, reverse=True):
for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]: