mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Tensor.where method (#830)
This commit is contained in:
@@ -66,13 +66,13 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(
|
||||
[(100,)],
|
||||
lambda x: torch.where(x > 0.5, 4, 2),
|
||||
lambda x: Tensor.where(x > 0.5, 4, 2), forward_only=True)
|
||||
lambda x: (x > 0.5).where(4, 2), forward_only=True)
|
||||
|
||||
for shps in [[(10,),(1,),(1,)], [(10,10),(10,),(10,)], [(100,)]*3, [(10,10)]*3]:
|
||||
helper_test_op(
|
||||
shps,
|
||||
lambda x, a, b: torch.where(x > 0.5, a, b),
|
||||
lambda x, a, b: Tensor.where(x > 0.5, a, b), forward_only=True)
|
||||
lambda x, a, b: (x > 0.5).where(a, b), forward_only=True)
|
||||
|
||||
def _test_cmp(self, fxn, reverse=True):
|
||||
for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
|
||||
|
||||
Reference in New Issue
Block a user