mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
lazy basic setitem to unrealized Tensor (#14756)
Undo the view and make it a mask; this also fuses the setitem with any pending compute. One behavior change: for a target not backed by a buffer (const and arange), rangeify makes the output contiguous under the hood. This is strictly better than raising and asking the user to call contiguous, as that would no longer be fusable.
This commit is contained in:
@@ -1323,6 +1323,55 @@ class TestMultiAssign(unittest.TestCase):
|
||||
f(out, vi.bind(i))
|
||||
self.assertListEqual(out.tolist(), [[0,1,2,3,4,0]]*4)
|
||||
|
||||
@unittest.skipIf(not_support_multi_device(), "need multi")
|
||||
class TestMultiSetitem(unittest.TestCase):
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))
|
||||
|
||||
@needs_second_gpu
|
||||
def setUp(self): pass
|
||||
|
||||
def _t(self, axis): return Tensor.arange(16).contiguous().realize().shard(self.device, axis=axis)
|
||||
|
||||
def test_setitem_scalar_axis0(self):
|
||||
t = self._t(0)
|
||||
t[1] = 99
|
||||
self.assertListEqual(t.tolist(), [0,99,2,3,4,5,6,7,8,9,10,11,12,13,14,15])
|
||||
|
||||
def test_setitem_scalar_axis_none(self):
|
||||
t = self._t(None)
|
||||
t[1] = 99
|
||||
self.assertListEqual(t.tolist(), [0,99,2,3,4,5,6,7,8,9,10,11,12,13,14,15])
|
||||
|
||||
def test_setitem_slice_cross_shard(self):
|
||||
t = self._t(0)
|
||||
t[2:6] = 99
|
||||
self.assertListEqual(t.tolist(), [0,1,99,99,99,99,6,7,8,9,10,11,12,13,14,15])
|
||||
|
||||
def test_setitem_full_slice(self):
|
||||
t = self._t(0)
|
||||
t[:] = 42
|
||||
self.assertListEqual(t.tolist(), [42]*16)
|
||||
|
||||
def test_setitem_stride(self):
|
||||
t = self._t(0)
|
||||
t[::4] = 0
|
||||
self.assertListEqual(t.tolist(), [0,1,2,3,0,5,6,7,0,9,10,11,0,13,14,15])
|
||||
|
||||
def test_setitem_single_shard(self):
|
||||
t = self._t(0)
|
||||
t[13] = 99
|
||||
self.assertListEqual(t.tolist(), [0,1,2,3,4,5,6,7,8,9,10,11,12,99,14,15])
|
||||
|
||||
def test_setitem_tensor_value_replicated(self):
|
||||
t = self._t(0)
|
||||
t[2:6] = Tensor([90, 91, 92, 93]).shard(self.device)
|
||||
self.assertListEqual(t.tolist(), [0,1,90,91,92,93,6,7,8,9,10,11,12,13,14,15])
|
||||
|
||||
def test_setitem_tensor_value_sharded_aligned(self):
|
||||
t = self._t(0)
|
||||
t[::4] = Tensor([90, 91, 92, 93]).shard(self.device, axis=0)
|
||||
self.assertListEqual(t.tolist(), [90,1,2,3,91,5,6,7,92,9,10,11,93,13,14,15])
|
||||
|
||||
@unittest.skipIf(not_support_multi_device(), "need multi")
|
||||
class TestMultiTransformer(unittest.TestCase):
|
||||
@needs_second_gpu
|
||||
|
||||
@@ -7,11 +7,13 @@ class TestSetitemInto(unittest.TestCase):
|
||||
t = Tensor.arange(4, dtype=dtypes.int32).reshape(2, 2)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4+4*2)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 16)
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertListEqual(t.tolist(), [[0, 1], [5, 5]])
|
||||
|
||||
def test_setitem_into_unrealized_sliced_compute(self):
|
||||
@@ -21,17 +23,21 @@ class TestSetitemInto(unittest.TestCase):
|
||||
w = a[0] + a[1] # unrealized ADD with SHRINK in graph: [4, 6, 8, 10]
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
w[1] = 99
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4+4)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
w.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4)
|
||||
self.assertListEqual(w.tolist(), [4, 99, 8, 10])
|
||||
|
||||
def test_setitem_into_empty(self):
|
||||
GlobalCounters.reset()
|
||||
t = Tensor.empty(4, dtype=dtypes.int32)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4)
|
||||
# TODO: this can be just 4 if empty goes through is_realized setitem path
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*(3*2+1)) # 3 elements had +1, 1 is assigned directly
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
@@ -42,11 +48,13 @@ class TestSetitemInto(unittest.TestCase):
|
||||
t = Tensor.empty(4, dtype=dtypes.int32) + 1
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4*2+4)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*(3*2+1)) # 3 elements had +1, 1 is assigned directly
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(t[1].item(), 5)
|
||||
|
||||
def test_setitem_into_tensor(self):
|
||||
@@ -65,44 +73,49 @@ class TestSetitemInto(unittest.TestCase):
|
||||
t = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize() + 1
|
||||
GlobalCounters.reset()
|
||||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4*2+4)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t[1].realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*(3*2+1)) # 3 elements had +1, 1 is assigned directly
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertListEqual(t.tolist(), [2, 5, 4, 5])
|
||||
|
||||
def test_setitem_into_cont(self):
|
||||
t = Tensor.ones(4, dtype=dtypes.int32)
|
||||
with self.assertRaises(RuntimeError): t[1] = 5
|
||||
|
||||
def test_setitem_into_const_alu(self):
|
||||
# TODO: this is not consistent
|
||||
GlobalCounters.reset()
|
||||
t = Tensor.ones(4, dtype=dtypes.int32) + 1
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t = Tensor.ones(4, dtype=dtypes.int32)
|
||||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4+4)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4)
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertListEqual(t.tolist(), [2, 5, 2, 2])
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertListEqual(t.tolist(), [1, 5, 1, 1])
|
||||
|
||||
def test_setitem_into_const_alu(self):
|
||||
GlobalCounters.reset()
|
||||
t = Tensor.ones(4, dtype=dtypes.int32) + 1
|
||||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
with self.assertRaises(RuntimeError): t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 4*4)
|
||||
t[1].realize()
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertListEqual(t.tolist(), [2, 5, 2, 2])
|
||||
|
||||
def test_setitem_into_arange(self):
|
||||
# NOTE: arange has no real buffer, but assigning to it is fine
|
||||
GlobalCounters.reset()
|
||||
t = Tensor.arange(4, dtype=dtypes.int32)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t[1] = 5
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
t[1].realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertListEqual(t.tolist(), [0, 5, 2, 3])
|
||||
|
||||
def test_setitem_slice_const(self):
|
||||
|
||||
@@ -43,7 +43,6 @@ def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
|
||||
if target is not t and target.op_in_backward_slice_with_self(Ops.SHRINK):
|
||||
# base already realized: copy src only if it reads from the same buffer (overlapping read/write hazard)
|
||||
if t.op is Ops.CONTIGUOUS: return assign.replace(src=(target, src.contiguous())) if t in src.toposort() else None
|
||||
if t.op is Ops.CONST: raise RuntimeError("setitem target must be a writable view backed by a buffer")
|
||||
mops: list[UOp] = []
|
||||
while target.op in GroupOp.Movement:
|
||||
mops.append(target)
|
||||
|
||||
@@ -1214,6 +1214,26 @@ class Tensor(OpMixin):
|
||||
x_dims = [p for p in indices_parsed if not isinstance(p['index'], sint)]
|
||||
x = x.reshape(tuple(p['size'] for p in x_dims))
|
||||
|
||||
# basic setitem: construct result with view region replaced by v using arange masks
|
||||
if v is not None and not any(isinstance(p['index'], Tensor) for p in indices_parsed):
|
||||
# broadcast v to getitem shape, reshape to self.ndim (squeeze None dims, unsqueeze int dims — all are size 1)
|
||||
vb = v.cast(self.dtype)._broadcast_to(x.shape)
|
||||
vb = vb.reshape(tuple(1 if isinstance(p['index'], sint) else p['size'] for p in indices_parsed if p['index'] is not None))
|
||||
# undo movement ops per-dim and build boolean mask
|
||||
per_dim = []
|
||||
for d, m in enumerate(mops):
|
||||
(s, e), st = m['boundary'], abs(m['stride'])
|
||||
if st != 1 and vb.shape[d] > 1: # un-stride: interleave with zeros
|
||||
vb = vb.unsqueeze(d+1)
|
||||
vb = vb.pad_to(tuple(st if j == d+1 else None for j in range(vb.ndim)))
|
||||
vb = vb.reshape(vb.shape[:d] + (vb.shape[d]*vb.shape[d+1],) + vb.shape[d+2:])
|
||||
vb = vb.shrink_to(tuple(e-s if j == d else None for j in range(self.ndim)))
|
||||
idx = Tensor.arange(self.shape[d], device=self.device).reshape([1]*d + [self.shape[d]] + [1]*(self.ndim - d - 1))
|
||||
per_dim.append((idx >= s) & (idx < e) & (((e-1-idx) if m['stride'] < 0 else (idx-s)) % st == 0))
|
||||
vb = vb.flip(tuple(d for d, m in enumerate(mops) if m['stride'] < 0))
|
||||
vb = vb.pad(tuple((m['boundary'][0], self.shape[d] - m['boundary'][1]) for d, m in enumerate(mops)))
|
||||
return (functools.reduce(lambda a, b: a & b, per_dim) if per_dim else Tensor(True, dtype=dtypes.bool, device=self.device)).where(vb, self)
|
||||
|
||||
# tensor indexing
|
||||
if tops := [(d, p) for d, p in enumerate(x_dims) if isinstance(p['index'], Tensor)]:
|
||||
dims, tensors, masks = [d for d, _ in tops], cast(list[Tensor], [p['index'] for _, p in tops]), []
|
||||
@@ -1312,7 +1332,10 @@ class Tensor(OpMixin):
|
||||
elif is_disk or self.uop.is_realized: # basic setitem, self is realized. TODO: disk uop.base is a COPY and not realized
|
||||
self[indices].assign(v)
|
||||
else: # basic setitem, self is not realized
|
||||
self[indices].assign(v).realize()
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
# __iadd__/__isub__ on unrealized views creates a no-op ASSIGN; unwrap to get the computed value
|
||||
if v.uop.op is Ops.ASSIGN: v = v._apply_uop(lambda x: x.src[1])
|
||||
self.replace(self._getitem(indices, v))
|
||||
|
||||
def __delitem__(self, indices) -> None:
|
||||
raise TypeError("Tensor does not support deleting items")
|
||||
|
||||
Reference in New Issue
Block a user