fix gather with inf values (#11760)

(mask * x) is wrong because 0*inf is nan. i feel we have a lot of those still...
This commit is contained in:
chenyu
2025-08-20 20:35:40 -04:00
committed by GitHub
parent b979162c5d
commit 5276fbc9c5
3 changed files with 3 additions and 7 deletions

View File

@@ -2804,11 +2804,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
vals=[[1., 2., 3.]])
@unittest.expectedFailure
@unittest.skipIf(torch._C._get_privateuse1_backend_name() == "tiny", 'results in a success instead of a failure')
def test_gather_failure(self):
# gather with inf values do not work, other values results in nan
# gather with inf values
helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
vals=[[-float("inf"), 2., 3.]])