mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix mixed setitem with both basic and tensor indexing (#15050)
This commit is contained in:
@@ -205,7 +205,6 @@ class TestSetitem(unittest.TestCase):
|
||||
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
|
||||
np.testing.assert_equal(t.numpy(), n)
|
||||
|
||||
@unittest.expectedFailure # TODO: fix
|
||||
def test_setitem_tensor_int_indexing(self):
|
||||
t = Tensor.zeros(4, 3, dtype=dtypes.int).contiguous()
|
||||
t[Tensor([0, 2]), 0] = Tensor([99, 88], dtype=dtypes.int)
|
||||
@@ -213,7 +212,6 @@ class TestSetitem(unittest.TestCase):
|
||||
n[[0, 2], 0] = [99, 88]
|
||||
np.testing.assert_equal(t.numpy(), n)
|
||||
|
||||
@unittest.expectedFailure # TODO: fix
|
||||
def test_setitem_tensor_slice_indexing(self):
|
||||
t = Tensor.zeros(4, 3, dtype=dtypes.int).contiguous()
|
||||
t[Tensor([0, 2]), :2] = Tensor([[10, 20], [30, 40]], dtype=dtypes.int)
|
||||
|
||||
@@ -1248,6 +1248,7 @@ class Tensor(OpMixin):
|
||||
# inject 1's for the extra dims added in create masks
|
||||
reshape_arg = x.shape[:dims[0]] + (1,) * len(big_shape) + x.shape[dims[0]:]
|
||||
# sum reduce the extra dims introduced in create masks
|
||||
x_pre = x # save collapsed shape for advanced setitem
|
||||
x = (mask.where(x.reshape(reshape_arg), 0)).sum(sum_axis:=tuple(d + len(big_shape) for d in dims), dtype=x.dtype)
|
||||
|
||||
# special permute case
|
||||
@@ -1255,14 +1256,14 @@ class Tensor(OpMixin):
|
||||
mask, x = (y.permute(*range(dims[0], dims[0]+len(big_shape)), *range(0, dims[0]), *range(dims[0]+len(big_shape), y.ndim)) for y in (mask, x))
|
||||
|
||||
if v is None: return x # advanced getitem
|
||||
# advanced setitem
|
||||
# advanced setitem: resolve tensor dims in collapsed space, then fall through to basic setitem path
|
||||
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(x.shape, v.shape))
|
||||
for dim in sum_axis: vb = vb.unsqueeze(dim) # add back reduced dims from sum
|
||||
return _masked_setitem(self, vb, mask, tuple(range((start := dims[0] if not permuted else 0), start + len(big_shape))))
|
||||
|
||||
if v is None: return x # basic getitem
|
||||
start = dims[0] if not permuted else 0
|
||||
vb = _masked_setitem(x_pre, vb, mask, tuple(range(start, start + len(big_shape))))
|
||||
elif v is None: return x # basic getitem
|
||||
# basic setitem: broadcast v, reshape to self.ndim (unsqueeze int dims, squeeze None dims)
|
||||
vb = v.cast(self.dtype)._broadcast_to(x.shape)
|
||||
else: vb = v.cast(self.dtype)._broadcast_to(x.shape)
|
||||
vb = vb.reshape(tuple(1 if isinstance(p['index'], sint) else p['size'] for p in indices_parsed if p['index'] is not None))
|
||||
per_dim = []
|
||||
for d, m in enumerate(mops):
|
||||
|
||||
Reference in New Issue
Block a user