diff --git a/test/test_setitem.py b/test/test_setitem.py index 967acc29f1..54ae9007af 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -165,6 +165,17 @@ class TestSetitem(unittest.TestCase): t[idx] = val self.assertEqual(t.tolist(), [val]*idx_size+[idx_size]) + def test_setitem_advanced_indexing(self): + # Example from https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing + t = Tensor.zeros(10,20,30,40,50).contiguous() + ind_1 = Tensor([5,3,7,8]) + ind_2 = Tensor([[[0],[1],[2]],[[3],[4],[5]]]) + v = Tensor.arange(2*3*4*10*30*50).reshape(2,3,4,10,30,50) + t[:, ind_1, :, ind_2, :] = v + n = np.zeros((10,20,30,40,50)) + n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy() + np.testing.assert_allclose(t.numpy(), n) + class TestWithGrad(unittest.TestCase): def test_no_requires_grad_works(self): z = Tensor.rand(8, 8) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index fb37bd7b18..b83d751282 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1220,8 +1220,8 @@ class Tensor(MathTrait): x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype) # special permute case - if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)): - x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim)) + if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))): + mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x)) # for advanced setitem, returns whole tensor with indices replaced if v is not None: @@ -1229,7 +1229,7 @@ class Tensor(MathTrait): # add back reduced dims from sum for dim in sum_axis: vb = vb.unsqueeze(dim) # run _masked_setitem on tuple of axis that is to be reduced to match self.shape - x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape)))) + x = _masked_setitem(self, vb, mask, tuple(range((start := dims[0] if not permuted else 0), start + len(big_shape)))) return x