mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix gather with inf values (#11760)
(mask * x) is wrong because 0*inf is nan. i feel we have a lot of those still...
This commit is contained in:
@@ -2804,11 +2804,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
|
||||
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
|
||||
vals=[[1., 2., 3.]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(torch._C._get_privateuse1_backend_name() == "tiny", 'results in a success instead of a failure')
|
||||
def test_gather_failure(self):
|
||||
# gather with inf values do not work, other values results in nan
|
||||
# gather with inf values
|
||||
helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
|
||||
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
|
||||
vals=[[-float("inf"), 2., 3.]])
|
||||
|
||||
Reference in New Issue
Block a user