fix mixed setitem with both basic and tensor indexing (#15050)

This commit is contained in:
chenyu
2026-02-27 15:35:48 -05:00
committed by GitHub
parent c9f6d8751b
commit db6b3e1edc
2 changed files with 6 additions and 7 deletions

View File

@@ -205,7 +205,6 @@ class TestSetitem(unittest.TestCase):
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
np.testing.assert_equal(t.numpy(), n)
@unittest.expectedFailure # TODO: fix
def test_setitem_tensor_int_indexing(self):
t = Tensor.zeros(4, 3, dtype=dtypes.int).contiguous()
t[Tensor([0, 2]), 0] = Tensor([99, 88], dtype=dtypes.int)
@@ -213,7 +212,6 @@ class TestSetitem(unittest.TestCase):
n[[0, 2], 0] = [99, 88]
np.testing.assert_equal(t.numpy(), n)
@unittest.expectedFailure # TODO: fix
def test_setitem_tensor_slice_indexing(self):
t = Tensor.zeros(4, 3, dtype=dtypes.int).contiguous()
t[Tensor([0, 2]), :2] = Tensor([[10, 20], [30, 40]], dtype=dtypes.int)

View File

@@ -1248,6 +1248,7 @@ class Tensor(OpMixin):
# inject 1's for the extra dims added in create masks
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
# sum reduce the extra dims introduced in create masks
x_pre = x # save collapsed shape for advanced setitem
x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
# special permute case
@@ -1255,14 +1256,14 @@ class Tensor(OpMixin):
mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x))
if v is None: return x # advanced getitem
# advanced setitem
# advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum
return _masked_setitem(self, vb, mask, tuple(range((start := dims[0] if not permuted else 0), start + len(big_shape))))
if v is None: return x # basic getitem
start = dims[0] if not permuted else 0
vb = _masked_setitem(x_pre, vb, mask, tuple(range(start, start + len(big_shape))))
elif v is None: return x # basic getitem
# basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims)
vb = v.cast(self.dtype)._broadcast_to(x.shape)
else: vb = v.cast(self.dtype)._broadcast_to(x.shape)
vb = vb.reshape(tuple(1 if isinstance(p['index'], sint) else p['size'] for p in indices_parsed if p['index'] is not None))
per_dim = []
for d, m in enumerate(mops):