diff --git a/test/backend/test_multitensor.py b/test/backend/test_multitensor.py index 1f2ce92e0c..0451c87a10 100644 --- a/test/backend/test_multitensor.py +++ b/test/backend/test_multitensor.py @@ -1323,6 +1323,55 @@ class TestMultiAssign(unittest.TestCase): f(out, vi.bind(i)) self.assertListEqual(out.tolist(), [[0,1,2,3,4,0]]*4) +@unittest.skipIf(not_support_multi_device(), "need multi") +class TestMultiSetitem(unittest.TestCase): + device = tuple(f"{Device.DEFAULT}:{i}" for i in range(4)) + + @needs_second_gpu + def setUp(self): pass + + def _t(self, axis): return Tensor.arange(16).contiguous().realize().shard(self.device, axis=axis) + + def test_setitem_scalar_axis0(self): + t = self._t(0) + t[1] = 99 + self.assertListEqual(t.tolist(), [0,99,2,3,4,5,6,7,8,9,10,11,12,13,14,15]) + + def test_setitem_scalar_axis_none(self): + t = self._t(None) + t[1] = 99 + self.assertListEqual(t.tolist(), [0,99,2,3,4,5,6,7,8,9,10,11,12,13,14,15]) + + def test_setitem_slice_cross_shard(self): + t = self._t(0) + t[2:6] = 99 + self.assertListEqual(t.tolist(), [0,1,99,99,99,99,6,7,8,9,10,11,12,13,14,15]) + + def test_setitem_full_slice(self): + t = self._t(0) + t[:] = 42 + self.assertListEqual(t.tolist(), [42]*16) + + def test_setitem_stride(self): + t = self._t(0) + t[::4] = 0 + self.assertListEqual(t.tolist(), [0,1,2,3,0,5,6,7,0,9,10,11,0,13,14,15]) + + def test_setitem_single_shard(self): + t = self._t(0) + t[13] = 99 + self.assertListEqual(t.tolist(), [0,1,2,3,4,5,6,7,8,9,10,11,12,99,14,15]) + + def test_setitem_tensor_value_replicated(self): + t = self._t(0) + t[2:6] = Tensor([90, 91, 92, 93]).shard(self.device) + self.assertListEqual(t.tolist(), [0,1,90,91,92,93,6,7,8,9,10,11,12,13,14,15]) + + def test_setitem_tensor_value_sharded_aligned(self): + t = self._t(0) + t[::4] = Tensor([90, 91, 92, 93]).shard(self.device, axis=0) + self.assertListEqual(t.tolist(), [90,1,2,3,91,5,6,7,92,9,10,11,93,13,14,15]) + @unittest.skipIf(not_support_multi_device(), "need multi") class 
TestMultiTransformer(unittest.TestCase): @needs_second_gpu diff --git a/test/unit/test_setitem_schedule.py b/test/unit/test_setitem_schedule.py index 7bd8ad8765..ac63ef1162 100644 --- a/test/unit/test_setitem_schedule.py +++ b/test/unit/test_setitem_schedule.py @@ -7,11 +7,13 @@ class TestSetitemInto(unittest.TestCase): t = Tensor.arange(4, dtype=dtypes.int32).reshape(2, 2) self.assertEqual(GlobalCounters.kernel_count, 0) t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 2) - self.assertEqual(GlobalCounters.global_mem, 4*4+4*2) + self.assertEqual(GlobalCounters.kernel_count, 0) + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 16) t[1].realize() t.realize() - self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.kernel_count, 1) self.assertListEqual(t.tolist(), [[0, 1], [5, 5]]) def test_setitem_into_unrealized_sliced_compute(self): @@ -21,17 +23,21 @@ class TestSetitemInto(unittest.TestCase): w = a[0] + a[1] # unrealized ADD with SHRINK in graph: [4, 6, 8, 10] self.assertEqual(GlobalCounters.kernel_count, 0) w[1] = 99 - self.assertEqual(GlobalCounters.kernel_count, 2) - self.assertEqual(GlobalCounters.global_mem, 4*4+4) + self.assertEqual(GlobalCounters.kernel_count, 0) + w.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 4*4) self.assertListEqual(w.tolist(), [4, 99, 8, 10]) def test_setitem_into_empty(self): GlobalCounters.reset() t = Tensor.empty(4, dtype=dtypes.int32) - self.assertEqual(GlobalCounters.kernel_count, 0) t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 0) + t.realize() self.assertEqual(GlobalCounters.kernel_count, 1) - self.assertEqual(GlobalCounters.global_mem, 4) + # TODO: this can be just 4 if empty goes through is_realized setitem path + self.assertEqual(GlobalCounters.global_mem, 4*(3*2+1)) # 3 elements are read+written unchanged, 1 is written directly t[1].realize() t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1) @@ -42,11 +48,13 @@ class TestSetitemInto(unittest.TestCase): t = Tensor.empty(4, dtype=dtypes.int32) + 1 self.assertEqual(GlobalCounters.kernel_count, 0) t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 2) - self.assertEqual(GlobalCounters.global_mem, 4*4*2+4) + self.assertEqual(GlobalCounters.kernel_count, 0) + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 4*(3*2+1)) # 3 elements had +1, 1 is assigned directly t[1].realize() t.realize() - self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.kernel_count, 1) self.assertEqual(t[1].item(), 5) def test_setitem_into_tensor(self): @@ -65,44 +73,49 @@ class TestSetitemInto(unittest.TestCase): t = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize() + 1 GlobalCounters.reset() t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 2) - self.assertEqual(GlobalCounters.global_mem, 4*4*2+4) + self.assertEqual(GlobalCounters.kernel_count, 0) + t[1].realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 4*(3*2+1)) # 3 elements had +1, 1 is assigned directly t[1].realize() t.realize() - self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.kernel_count, 1) self.assertListEqual(t.tolist(), [2, 5, 4, 5]) def test_setitem_into_cont(self): - t = Tensor.ones(4, dtype=dtypes.int32) - with self.assertRaises(RuntimeError): t[1] = 5 - - def test_setitem_into_const_alu(self): - # TODO: this is not consistent GlobalCounters.reset() - t = Tensor.ones(4, dtype=dtypes.int32) + 1 - self.assertEqual(GlobalCounters.kernel_count, 0) + t = Tensor.ones(4, dtype=dtypes.int32) t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 2) - self.assertEqual(GlobalCounters.global_mem, 4*4+4) + self.assertEqual(GlobalCounters.kernel_count, 0) + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + 
self.assertEqual(GlobalCounters.global_mem, 4*4) t[1].realize() t.realize() - self.assertEqual(GlobalCounters.kernel_count, 2) - self.assertListEqual(t.tolist(), [2, 5, 2, 2]) + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertListEqual(t.tolist(), [1, 5, 1, 1]) + def test_setitem_into_const_alu(self): + GlobalCounters.reset() t = Tensor.ones(4, dtype=dtypes.int32) + 1 + t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 0) t.realize() - with self.assertRaises(RuntimeError): t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 4*4) + t[1].realize() + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertListEqual(t.tolist(), [2, 5, 2, 2]) def test_setitem_into_arange(self): # NOTE: arange has no real buffer, but assigning to it is fine GlobalCounters.reset() t = Tensor.arange(4, dtype=dtypes.int32) - self.assertEqual(GlobalCounters.kernel_count, 0) t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 2) - t[1].realize() + self.assertEqual(GlobalCounters.kernel_count, 0) t.realize() - self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.kernel_count, 1) self.assertListEqual(t.tolist(), [0, 5, 2, 3]) def test_setitem_slice_const(self): diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 82da3cc366..6211afb03f 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -43,7 +43,6 @@ def assign_to_contiguous(assign:UOp, target:UOp, src:UOp): if target is not t and target.op_in_backward_slice_with_self(Ops.SHRINK): # base already realized: copy src only if it reads from the same buffer (overlapping read/write hazard) if t.op is Ops.CONTIGUOUS: return assign.replace(src=(target, src.contiguous())) if t in src.toposort() else None - if t.op is Ops.CONST: raise RuntimeError("setitem target must be a writable view backed by a buffer") mops: list[UOp] = [] while target.op in GroupOp.Movement: 
mops.append(target) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c45300e0c9..4bac773cc9 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1214,6 +1214,26 @@ class Tensor(OpMixin): x_dims = [p for p in indices_parsed if not isinstance(p['index'], sint)] x = x.reshape(tuple(p['size'] for p in x_dims)) + # basic setitem: construct result with view region replaced by v using arange masks + if v is not None and not any(isinstance(p['index'], Tensor) for p in indices_parsed): + # broadcast v to getitem shape, reshape to self.ndim (squeeze None dims, unsqueeze int dims — all are size 1) + vb = v.cast(self.dtype)._broadcast_to(x.shape) + vb = vb.reshape(tuple(1 if isinstance(p['index'], sint) else p['size'] for p in indices_parsed if p['index'] is not None)) + # undo movement ops per-dim and build boolean mask + per_dim = [] + for d, m in enumerate(mops): + (s, e), st = m['boundary'], abs(m['stride']) + if st != 1 and vb.shape[d] > 1: # un-stride: interleave with zeros + vb = vb.unsqueeze(d+1) + vb = vb.pad_to(tuple(st if j == d+1 else None for j in range(vb.ndim))) + vb = vb.reshape(vb.shape[:d] + (vb.shape[d]*vb.shape[d+1],) + vb.shape[d+2:]) + vb = vb.shrink_to(tuple(e-s if j == d else None for j in range(self.ndim))) + idx = Tensor.arange(self.shape[d], device=self.device).reshape([1]*d + [self.shape[d]] + [1]*(self.ndim - d - 1)) + per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0)) + vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0)) + vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops))) + return (functools.reduce(lambda a, b: a & b, per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self) + # tensor indexing if tops := [(d, p) for d, p in enumerate(x_dims) if isinstance(p['index'], Tensor)]: dims, tensors, masks = [d for d, _ in tops], cast(list[Tensor], [p['index'] for _, p in tops]), [] @@ 
-1312,7 +1332,10 @@ class Tensor(OpMixin): elif is_disk or self.uop.is_realized: # basic setitem, self is realized. TODO: disk uop.base is a COPY and not realized self[indices].assign(v) else: # basic setitem, self is not realized - self[indices].assign(v).realize() + if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) + # __iadd__/__isub__ on unrealized views creates a no-op ASSIGN; unwrap to get the computed value + if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1]) + self.replace(self._getitem(indices, v)) def __delitem__(self, indices) -> None: raise TypeError("Tensor does not support deleting items")