From db6b3e1edc191f6aa5be7ba7ae03a67e27aa9224 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 27 Feb 2026 15:35:48 -0500 Subject: [PATCH] fix mixed setitem with both basic and tensor indexing (#15050) --- test/backend/test_setitem.py | 2 -- tinygrad/tensor.py | 11 ++++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/backend/test_setitem.py b/test/backend/test_setitem.py index a4edac2e92..e4d43463a0 100644 --- a/test/backend/test_setitem.py +++ b/test/backend/test_setitem.py @@ -205,7 +205,6 @@ class TestSetitem(unittest.TestCase): n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy() np.testing.assert_equal(t.numpy(), n) - @unittest.expectedFailure # TODO: fix def test_setitem_tensor_int_indexing(self): t = Tensor.zeros(4, 3, dtype=dtypes.int).contiguous() t[Tensor([0, 2]), 0] = Tensor([99, 88], dtype=dtypes.int) @@ -213,7 +212,6 @@ class TestSetitem(unittest.TestCase): n[[0, 2], 0] = [99, 88] np.testing.assert_equal(t.numpy(), n) - @unittest.expectedFailure # TODO: fix def test_setitem_tensor_slice_indexing(self): t = Tensor.zeros(4, 3, dtype=dtypes.int).contiguous() t[Tensor([0, 2]), :2] = Tensor([[10, 20], [30, 40]], dtype=dtypes.int) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5f45ae6924..8907add7c4 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1248,6 +1248,7 @@ class Tensor(OpMixin): # inject 1's for the extra dims added in create masks reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:] # sum reduce the extra dims introduced in create masks + x_pre = x # save collapsed shape for advanced setitem x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype) # special permute case @@ -1255,14 +1256,14 @@ class Tensor(OpMixin): mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x)) if v is None: return x # advanced getitem - # advanced setitem + # advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape)) for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum - return _masked_setitem(self, vb, mask, tuple(range((start := dims[0] if not permuted else 0), start + len(big_shape)))) - - if v is None: return x # basic getitem + start = dims[0] if not permuted else 0 + vb = _masked_setitem(x_pre, vb, mask, tuple(range(start, start + len(big_shape)))) + elif v is None: return x # basic getitem # basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims) - vb = v.cast(self.dtype)._broadcast_to(x.shape) + else: vb = v.cast(self.dtype)._broadcast_to(x.shape) vb = vb.reshape(tuple(1 if isinstance(p['index'], sint) else p['size'] for p in indices_parsed if p['index'] is not None)) per_dim = [] for d, m in enumerate(mops):