Files
tinygrad/test/unit/test_setitem_schedule.py
chenyu ca68037f26 lazy basic setitem to unrealized Tensor (#14756)
Undo the view and make it a mask; this also fuses the setitem with any pending compute.

One behavior change: for a target not backed by a buffer (const and arange), rangeify makes the output contiguous under the hood.
This is strictly better than raising and asking the user to call contiguous, as that would no longer be fusable.
2026-02-14 20:27:03 -05:00

151 lines
5.5 KiB
Python

import unittest
from tinygrad import Tensor, dtypes, GlobalCounters
class TestSetitemInto(unittest.TestCase):
  """Scheduling checks for basic ``Tensor.__setitem__``.

  Every test performs a setitem on some kind of target (unrealized compute, empty,
  realized buffer, const, arange, slices) and asserts on ``GlobalCounters.kernel_count``
  and, in most tests, ``GlobalCounters.global_mem`` before and after ``realize()``.
  """

  def _expect_kernels(self, count):
    # Compare the current GlobalCounters.kernel_count against the expected value.
    self.assertEqual(GlobalCounters.kernel_count, count)

  def _expect_mem(self, amount):
    # Compare the current GlobalCounters.global_mem against the expected value.
    self.assertEqual(GlobalCounters.global_mem, amount)

  def test_setitem_into_unrealized(self):
    GlobalCounters.reset()
    grid = Tensor.arange(4, dtype=dtypes.int32).reshape(2, 2)
    self._expect_kernels(0)
    grid[1] = 5
    # the setitem by itself must not run anything
    self._expect_kernels(0)
    grid.realize()
    self._expect_kernels(1)
    self._expect_mem(16)
    # realizing a row and then the whole tensor again must not add kernels
    grid[1].realize()
    grid.realize()
    self._expect_kernels(1)
    self.assertListEqual(grid.tolist(), [[0, 1], [5, 5]])

  def test_setitem_into_unrealized_sliced_compute(self):
    # the target's computation already contains SHRINK from earlier slicing (QR-decomposition-like pattern)
    GlobalCounters.reset()
    base = Tensor.arange(8, dtype=dtypes.int32).reshape(2, 4)
    row_sum = base[0] + base[1]  # unrealized ADD with SHRINK in graph: [4, 6, 8, 10]
    self._expect_kernels(0)
    row_sum[1] = 99
    self._expect_kernels(0)
    row_sum.realize()
    self._expect_kernels(1)
    self._expect_mem(4*4)
    self.assertListEqual(row_sum.tolist(), [4, 99, 8, 10])

  def test_setitem_into_empty(self):
    GlobalCounters.reset()
    target = Tensor.empty(4, dtype=dtypes.int32)
    target[1] = 5
    self._expect_kernels(0)
    target.realize()
    self._expect_kernels(1)
    # TODO: this can be just 4 if empty goes through is_realized setitem path
    # NOTE(review): 4*(3*2+1) presumably means 3 untouched elements are read and written back
    # while 1 element is written directly -- confirm
    self._expect_mem(4*(3*2+1))
    target[1].realize()
    target.realize()
    self._expect_kernels(1)
    self.assertEqual(target[1].item(), 5)

  def test_setitem_into_empty_alu(self):
    GlobalCounters.reset()
    target = Tensor.empty(4, dtype=dtypes.int32) + 1
    self._expect_kernels(0)
    target[1] = 5
    self._expect_kernels(0)
    target.realize()
    self._expect_kernels(1)
    # 3 elements had +1, 1 is assigned directly
    self._expect_mem(4*(3*2+1))
    target[1].realize()
    target.realize()
    self._expect_kernels(1)
    self.assertEqual(target[1].item(), 5)

  def test_setitem_into_tensor(self):
    # the target is realized before the counters are reset
    target = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize()
    GlobalCounters.reset()
    target[1] = 5
    self._expect_kernels(0)
    target[1].realize()
    self._expect_kernels(1)
    self._expect_mem(4)
    target.realize()
    self._expect_kernels(1)
    self.assertListEqual(target.tolist(), [1, 5, 3, 4])

  def test_setitem_into_tensor_alu(self):
    # the target is a pending +1 on top of a realized tensor
    target = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize() + 1
    GlobalCounters.reset()
    target[1] = 5
    self._expect_kernels(0)
    target[1].realize()
    self._expect_kernels(1)
    # 3 elements had +1, 1 is assigned directly
    self._expect_mem(4*(3*2+1))
    target[1].realize()
    target.realize()
    self._expect_kernels(1)
    self.assertListEqual(target.tolist(), [2, 5, 4, 5])

  def test_setitem_into_cont(self):
    GlobalCounters.reset()
    target = Tensor.ones(4, dtype=dtypes.int32)
    target[1] = 5
    self._expect_kernels(0)
    target.realize()
    self._expect_kernels(1)
    self._expect_mem(4*4)
    target[1].realize()
    target.realize()
    self._expect_kernels(1)
    self.assertListEqual(target.tolist(), [1, 5, 1, 1])

  def test_setitem_into_const_alu(self):
    GlobalCounters.reset()
    target = Tensor.ones(4, dtype=dtypes.int32) + 1
    target[1] = 5
    self._expect_kernels(0)
    target.realize()
    self._expect_kernels(1)
    self._expect_mem(4*4)
    target[1].realize()
    target.realize()
    self._expect_kernels(1)
    self.assertListEqual(target.tolist(), [2, 5, 2, 2])

  def test_setitem_into_arange(self):
    # NOTE: arange has no real buffer, but assigning to it is fine
    GlobalCounters.reset()
    target = Tensor.arange(4, dtype=dtypes.int32)
    target[1] = 5
    self._expect_kernels(0)
    target.realize()
    self._expect_kernels(1)
    self.assertListEqual(target.tolist(), [0, 5, 2, 3])

  def test_setitem_slice_const(self):
    target = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
    GlobalCounters.reset()
    target[20:50] = 3
    self._expect_kernels(0)
    target.realize()
    self._expect_kernels(1)
    # 30 elements written
    self._expect_mem(30*4)

  def test_setitem_slice_tensor(self):
    target = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
    src = Tensor.zeros(30, dtype=dtypes.int32).contiguous().realize()
    GlobalCounters.reset()
    target[20:50] = src
    self._expect_kernels(0)
    target.realize()
    self._expect_kernels(1)
    # 30 read + 30 written
    self._expect_mem(30*4*2)

  def test_setitem_full(self):
    target = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
    GlobalCounters.reset()
    target[:] = 3
    self._expect_kernels(0)
    target.realize()
    self._expect_kernels(1)
    # full buffer written
    self._expect_mem(100*4)
# Run the test cases in this module when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()