Fix advanced tensor indexing setitem (#12128)

* Add failure test case for advanced tensor indexing setitem

* Fix advanced tensor indexing setitem when permuted

* Reduce line count

* Revert unnecessary change

* Combine two lines into one
This commit is contained in:
Shun Usami
2025-09-14 12:22:40 -07:00
committed by GitHub
parent d09c0f28c5
commit 34a05b31fe
2 changed files with 14 additions and 3 deletions

View File

@@ -165,6 +165,17 @@ class TestSetitem(unittest.TestCase):
t[idx] = val
self.assertEqual(t.tolist(), [val]*idx_size+[idx_size])
def test_setitem_advanced_indexing(self):
# Example from https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
t = Tensor.zeros(10,20,30,40,50).contiguous()
ind_1 = Tensor([5,3,7,8])
ind_2 = Tensor([[[0],[1],[2]],[[3],[4],[5]]])
v = Tensor.arange(2*3*4*10*30*50).reshape(2,3,4,10,30,50)
t[:, ind_1, :, ind_2, :] = v
n = np.zeros((10,20,30,40,50))
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
np.testing.assert_allclose(t.numpy(), n)
class TestWithGrad(unittest.TestCase):
def test_no_requires_grad_works(self):
z = Tensor.rand(8, 8)

View File

@@ -1220,8 +1220,8 @@ class Tensor(MathTrait):
x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
# special permute case
if dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1)):
x = x.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), x.ndim))
if (permuted := dims[0] != 0 and len(dims) != 1 and tuple(dims) != tuple(range(dims[0], dims[-1]+1))):
mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x))
# for advanced setitem, returns whole tensor with indices replaced
if v is not None:
@@ -1229,7 +1229,7 @@ class Tensor(MathTrait):
# add back reduced dims from sum
for dim in sum_axis: vb = vb.unsqueeze(dim)
# run _masked_setitem on tuple of axis that is to be reduced to match self.shape
x = _masked_setitem(self, vb, mask, tuple(range(dims[0], dims[0] + len(big_shape))))
x = _masked_setitem(self, vb, mask, tuple(range((start := dims[0] if not permuted else 0), start + len(big_shape))))
return x