fix gather with inf values (#11760)

(mask * x) is wrong because 0*inf is nan. i feel we have a lot of those still...
This commit is contained in:
chenyu
2025-08-20 20:35:40 -04:00
committed by GitHub
parent b979162c5d
commit 5276fbc9c5
3 changed files with 3 additions and 7 deletions

View File

@@ -605,7 +605,7 @@ jobs:
deps: testing_minimal
- name: Test CPU=1 RANGEIFY=1
# TODO: add more passing tests here
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py -k "not test_gather_failure" --durations 20
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
testdevectorize:
name: Linux (devectorize)

View File

@@ -2804,11 +2804,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
vals=[[1., 2., 3.]])
@unittest.expectedFailure
@unittest.skipIf(torch._C._get_privateuse1_backend_name() == "tiny", 'results in a success instead of a failure')
def test_gather_failure(self):
# gather with inf values do not work, other values results in nan
# gather with inf values
helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
vals=[[-float("inf"), 2., 3.]])

View File

@@ -1303,7 +1303,7 @@ class Tensor(MathTrait):
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
index = index.to(self.device)
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, dtype=self.dtype)
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
"""