fix getitem with inf in tensor (#11781)

This commit is contained in:
chenyu
2025-08-21 21:55:32 -04:00
committed by GitHub
parent 66e9d54eed
commit 91a4de4ca7
3 changed files with 6 additions and 1 deletion

View File

@@ -117,6 +117,7 @@ class TestLinearizer(unittest.TestCase):
if skip and i in skip: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
@unittest.skip("broken. should not depends on push_views and implementation details of getitem")
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
def test_indexing_multireduce(self):
dataset = Tensor.rand(16384, 256).realize()

View File

@@ -2694,6 +2694,10 @@ class TestOps(unittest.TestCase):
i, j, k, o, p = [Tensor(tor.detach().cpu().numpy().astype(np.int32), requires_grad=False) for tor in [a,b,c,d,e]]
return a,b,c,d,e,i,j,k,o,p
def test_fancy_indexing_inf(self):
    # fancy (tensor) indexing must carry non-finite values through unchanged:
    # inf, -inf and nan should survive the gather instead of being zeroed out.
    values = [math.inf, -math.inf, math.nan]
    indices = [0, 1, 2]
    helper_test_op((),
      lambda: torch.tensor(values)[torch.tensor(indices)],
      lambda: Tensor(values)[Tensor(indices)])
def test_slice_fancy_indexing_no_dim_collapse(self):
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
# no dim collapse from int or dim injection from None

View File

@@ -1212,7 +1212,7 @@ class Tensor(MathTrait):
# inject 1's for the extra dims added in create masks
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
# sum reduce the extra dims introduced in create masks
x = (x.reshape(reshape_arg) * mask).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
# special permute case
if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):