mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
fix advanced setitem overlap with 0 (#7793)
* fix advanced setitem overlap with 0 * fix comment
This commit is contained in:
@@ -89,10 +89,7 @@ class TestSetitem(unittest.TestCase):
|
||||
t[[1,1]] = Tensor([0,1])
|
||||
np.testing.assert_allclose(t.numpy(), [1,1,3,4])
|
||||
|
||||
# TODO: #7739 fix when setting value 0 to overlapping indices
|
||||
# error occurs when previous overlapped values are non-zero and last overlapping value is zero
|
||||
@unittest.expectedFailure
|
||||
def test_setitem_overlapping_indices_failure(self):
|
||||
def test_setitem_overlapping_indices_with_0(self):
|
||||
t = Tensor([1,2,3,4])
|
||||
t[[1,1]] = Tensor([1,0])
|
||||
np.testing.assert_allclose(t.numpy(), [1,0,3,4])
|
||||
|
||||
@@ -1192,11 +1192,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
for dim in sum_axis: vb = vb.unsqueeze(dim)
|
||||
# axis to be reduced to match self.shape
|
||||
axis = tuple(range(first_dim, first_dim + len(big_shape)))
|
||||
# apply mask to v(broadcasted) and reduce such that if v contains repeated indices the last one remains
|
||||
# apply mask to vb(broadcasted) and reduce such that if mask contains repeated indices the last one remains
|
||||
vb = vb * mask
|
||||
for dim in axis: vb = functools.reduce(lambda x,y: y.where(y, x), vb.split(1, dim))
|
||||
# reduce mask and select from v(get rid of extra dims from reduce) for each True element in mask else select from self
|
||||
ret = mask.any(axis).where(vb.squeeze(), self)
|
||||
for dim in axis: mask, vb = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), vb.split(1, dim)))
|
||||
# select from vb for each True element in mask else select from self, squeeze to remove extra dims from reduce
|
||||
ret = mask.squeeze().where(vb.squeeze(), self)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user