fix advanced setitem overlap with 0 (#7793)

* fix advanced setitem overlap with 0

* fix comment
This commit is contained in:
ttomsa
2024-11-19 21:03:55 +00:00
committed by GitHub
parent 159c0bf25e
commit 170ece6605
2 changed files with 5 additions and 8 deletions

View File

@@ -89,10 +89,7 @@ class TestSetitem(unittest.TestCase):
t[[1,1]] = Tensor([0,1])
np.testing.assert_allclose(t.numpy(), [1,1,3,4])
# TODO: #7739 fix when setting value 0 to overlapping indices
# error occurs when previous overlapped values are non-zero and last overlapping value is zero
@unittest.expectedFailure
def test_setitem_overlapping_indices_failure(self):
def test_setitem_overlapping_indices_with_0(self):
t = Tensor([1,2,3,4])
t[[1,1]] = Tensor([1,0])
np.testing.assert_allclose(t.numpy(), [1,0,3,4])

View File

@@ -1192,11 +1192,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
for dim in sum_axis: vb = vb.unsqueeze(dim)
# axis to be reduced to match self.shape
axis = tuple(range(first_dim, first_dim + len(big_shape)))
# apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
# apply mask to vb(broadcasted) and reduce such that if mask contains repeated indices the last one remains
vb = vb * mask
for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
# reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
ret = mask.any(axis).where(vb.squeeze(), self)
for dim in axis: mask, vb = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), vb.split(1, dim)))
# select from vb for each True element in mask else select from self, squeeze to remove extra dims from reduce
ret = mask.squeeze().where(vb.squeeze(), self)
return ret