diff --git a/test/test_setitem.py b/test/test_setitem.py index d03d28322d..270bfa8c89 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -89,10 +89,7 @@ class TestSetitem(unittest.TestCase): t[[1,1]] = Tensor([0,1]) np.testing.assert_allclose(t.numpy(), [1,1,3,4]) - # TODO: #7739 fix when setting value 0 to overlapping indices - # error occurs when previous overlapped values are non-zero and last overlapping value is zero - @unittest.expectedFailure - def test_setitem_overlapping_indices_failure(self): + def test_setitem_overlapping_indices_with_0(self): t = Tensor([1,2,3,4]) t[[1,1]] = Tensor([1,0]) np.testing.assert_allclose(t.numpy(), [1,0,3,4]) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 602a692a1e..63bd57cf03 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1192,11 +1192,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method for dim in sum_axis: vb = vb.unsqueeze(dim) # axis to be reduced to match self.shape axis = tuple(range(first_dim, first_dim + len(big_shape))) - # apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains + # apply mask to vb(broadcasted) and reduce such that if mask contains repeated indices the last one remains vb = vb * mask - for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim)) - # reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self - ret = mask.any(axis).where(vb.squeeze(), self) + for dim in axis: mask, vb = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), vb.split(1, dim))) + # select from vb for each True element in mask else select from self, squeeze to remove extra dims from reduce + ret = mask.squeeze().where(vb.squeeze(), self) return ret