Fix advanced tensor indexing setitem (#12128)

* Add failure test case for advanced tensor indexing setitem

* Fix advanced tensor indexing setitem when permuted

* Reduce line count

* Revert unnecessary change

* Combine two lines into one
This commit is contained in:
Shun Usami
2025-09-14 12:22:40 -07:00
committed by GitHub
parent d09c0f28c5
commit 34a05b31fe
2 changed files with 14 additions and 3 deletions

View File

@@ -165,6 +165,17 @@ class TestSetitem(unittest.TestCase):
t[idx] = val
self.assertEqual(t.tolist(), [val]*idx_size+[idx_size])
def test_setitem_advanced_indexing(self):
# Example from https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
t = Tensor.zeros(10,20,30,40,50).contiguous()
ind_1 = Tensor([5,3,7,8])
ind_2 = Tensor([[[0],[1],[2]],[[3],[4],[5]]])
v = Tensor.arange(2*3*4*10*30*50).reshape(2,3,4,10,30,50)
t[:, ind_1, :, ind_2, :] = v
n = np.zeros((10,20,30,40,50))
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
np.testing.assert_allclose(t.numpy(), n)
class TestWithGrad(unittest.TestCase):
def test_no_requires_grad_works(self):
z = Tensor.rand(8, 8)