mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
fix gather with inf values (#11760)
(mask * x) is wrong because 0*inf is nan. i feel we have a lot of those still...
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -605,7 +605,7 @@ jobs:
|
||||
deps: testing_minimal
|
||||
- name: Test CPU=1 RANGEIFY=1
|
||||
# TODO: add more passing tests here
|
||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py -k "not test_gather_failure" --durations 20
|
||||
run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
|
||||
|
||||
testdevectorize:
|
||||
name: Linux (devectorize)
|
||||
|
||||
@@ -2804,11 +2804,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
|
||||
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
|
||||
vals=[[1., 2., 3.]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipIf(torch._C._get_privateuse1_backend_name() == "tiny", 'results in a success instead of a failure')
|
||||
def test_gather_failure(self):
|
||||
# gather with inf values do not work, other values results in nan
|
||||
# gather with inf values
|
||||
helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
|
||||
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
|
||||
vals=[[-float("inf"), 2., 3.]])
|
||||
|
||||
@@ -1303,7 +1303,7 @@ class Tensor(MathTrait):
|
||||
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
|
||||
index = index.to(self.device)
|
||||
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
||||
return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, dtype=self.dtype)
|
||||
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
|
||||
|
||||
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user