From 8f6772fd8ce680bc7ad507160d5f3ebedc904856 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 14 Feb 2026 11:01:03 -0500 Subject: [PATCH] more setitem kernel mem tests (#14749) * more setitem kernel mem tests test only the slice is accessed * update --- test/backend/test_setitem.py | 70 +----------------- test/unit/test_setitem_schedule.py | 109 +++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 69 deletions(-) create mode 100644 test/unit/test_setitem_schedule.py diff --git a/test/backend/test_setitem.py b/test/backend/test_setitem.py index aee93dc775..4626d2875b 100644 --- a/test/backend/test_setitem.py +++ b/test/backend/test_setitem.py @@ -1,5 +1,5 @@ import unittest -from tinygrad import Tensor, TinyJit, Variable, dtypes, Device, GlobalCounters +from tinygrad import Tensor, TinyJit, Variable, dtypes, Device from tinygrad.helpers import Context import numpy as np @@ -36,18 +36,6 @@ class TestSetitem(unittest.TestCase): t[:3] *= 10 self.assertListEqual(t.tolist(), [0, 10, 20, 3, 4, 5, 6, 7, 8, 9]) - def test_setitem_into_unrealized(self): - t = Tensor.arange(4).reshape(2, 2) - t[1] = 5 - np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]]) - - def test_setitem_into_unrealized_sliced_compute(self): - # base computation contains SHRINK from prior slicing (like QR decomposition pattern) - a = Tensor.arange(6, dtype=dtypes.float).reshape(2, 3) - w = a[0] + a[1] # unrealized ADD with SHRINK in graph: [3, 5, 7] - w[1] = 99 - np.testing.assert_allclose(w.numpy(), [3, 99, 7]) - def test_setitem_fancy_on_unrealized_view(self): # fancy indexing setitem on unrealized SHRINK view (triggered infinite loop in graph_rewrite) base = Tensor.arange(20, dtype=dtypes.float).reshape(4, 5) @@ -69,62 +57,6 @@ class TestSetitem(unittest.TestCase): t = Tensor.zeros(6, dtype=dtypes.float).contiguous().realize() with self.assertRaises(RuntimeError): t[2:4] = Tensor([1, 2], dtype=dtypes.int) - def test_setitem_into_empty(self): - GlobalCounters.reset() - t = Tensor.empty(4) - self.assertEqual(GlobalCounters.kernel_count, 0) - t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 1) - t[1].realize() - self.assertEqual(GlobalCounters.kernel_count, 1) - t.realize() - self.assertEqual(GlobalCounters.kernel_count, 1) - self.assertEqual(t[1].item(), 5) - - def test_setitem_into_tensor(self): - t = Tensor([1, 2, 3, 4]).realize() - GlobalCounters.reset() - self.assertEqual(GlobalCounters.kernel_count, 0) - t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 0) - t[1].realize() - self.assertEqual(GlobalCounters.kernel_count, 1) - t.realize() - self.assertEqual(GlobalCounters.kernel_count, 1) - self.assertListEqual(t.tolist(), [1, 5, 3, 4]) - - def test_setitem_into_cont(self): - t = Tensor.ones(4) - with self.assertRaises(RuntimeError): t[1] = 5 - - def test_setitem_into_const_alu(self): - # TODO: this is not consistent - GlobalCounters.reset() - t = Tensor.ones(4) + Tensor.ones(4) - self.assertEqual(GlobalCounters.kernel_count, 0) - t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 2) - t[1].realize() - self.assertEqual(GlobalCounters.kernel_count, 2) - t.realize() - self.assertEqual(GlobalCounters.kernel_count, 2) - self.assertListEqual(t.tolist(), [2, 5, 2, 2]) - - t = Tensor.ones(4) + Tensor.ones(4) - t.realize() - with self.assertRaises(RuntimeError): t[1] = 5 - - def test_setitem_into_arange(self): - # NOTE: arange has no real buffer, but assigning to it is fine - GlobalCounters.reset() - t = Tensor.arange(4) - self.assertEqual(GlobalCounters.kernel_count, 0) - t[1] = 5 - self.assertEqual(GlobalCounters.kernel_count, 2) - t.realize() - self.assertEqual(GlobalCounters.kernel_count, 2) - self.assertListEqual(t.tolist(), [0, 5, 2, 3]) - def test_setitem_chained_indexing(self): # N[i][j] must work the same as N[i, j] N1 = Tensor.zeros((3, 3)).contiguous().realize() diff --git a/test/unit/test_setitem_schedule.py b/test/unit/test_setitem_schedule.py new file mode 100644 index 0000000000..c583c316be --- /dev/null +++ b/test/unit/test_setitem_schedule.py @@ -0,0 +1,109 @@ +import unittest +from tinygrad import Tensor, dtypes, GlobalCounters + +class TestSetitemInto(unittest.TestCase): + def test_setitem_into_unrealized(self): + GlobalCounters.reset() + t = Tensor.arange(4, dtype=dtypes.int32).reshape(2, 2) + self.assertEqual(GlobalCounters.kernel_count, 0) + t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.global_mem, 4*4+4*2) + t[1].realize() + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertListEqual(t.tolist(), [[0, 1], [5, 5]]) + + def test_setitem_into_unrealized_sliced_compute(self): + # base computation contains SHRINK from prior slicing (like QR decomposition pattern) + GlobalCounters.reset() + a = Tensor.arange(8, dtype=dtypes.int32).reshape(2, 4) + w = a[0] + a[1] # unrealized ADD with SHRINK in graph: [4, 6, 8, 10] + self.assertEqual(GlobalCounters.kernel_count, 0) + w[1] = 99 + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.global_mem, 4*4+4) + self.assertListEqual(w.tolist(), [4, 99, 8, 10]) + + def test_setitem_into_empty(self): + GlobalCounters.reset() + t = Tensor.empty(4, dtype=dtypes.int32) + self.assertEqual(GlobalCounters.kernel_count, 0) + t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 4) + t[1].realize() + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(t[1].item(), 5) + + def test_setitem_into_empty_alu(self): + GlobalCounters.reset() + t = Tensor.empty(4, dtype=dtypes.int32) + 1 + self.assertEqual(GlobalCounters.kernel_count, 0) + t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.global_mem, 4*4*2+4) + t[1].realize() + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(t[1].item(), 5) + + def test_setitem_into_tensor(self): + t = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize() + GlobalCounters.reset() + t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 0) + t[1].realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 4) + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertListEqual(t.tolist(), [1, 5, 3, 4]) + + def test_setitem_into_tensor_alu(self): + t = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize() + 1 + GlobalCounters.reset() + t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.global_mem, 4*4*2+4) + t[1].realize() + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertListEqual(t.tolist(), [2, 5, 4, 5]) + + def test_setitem_into_cont(self): + t = Tensor.ones(4, dtype=dtypes.int32) + with self.assertRaises(RuntimeError): t[1] = 5 + + def test_setitem_into_const_alu(self): + # TODO: this is not consistent + GlobalCounters.reset() + t = Tensor.ones(4, dtype=dtypes.int32) + 1 + self.assertEqual(GlobalCounters.kernel_count, 0) + t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertEqual(GlobalCounters.global_mem, 4*4+4) + t[1].realize() + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertListEqual(t.tolist(), [2, 5, 2, 2]) + + t = Tensor.ones(4, dtype=dtypes.int32) + 1 + t.realize() + with self.assertRaises(RuntimeError): t[1] = 5 + + def test_setitem_into_arange(self): + # NOTE: arange has no real buffer, but assigning to it is fine + GlobalCounters.reset() + t = Tensor.arange(4, dtype=dtypes.int32) + self.assertEqual(GlobalCounters.kernel_count, 0) + t[1] = 5 + self.assertEqual(GlobalCounters.kernel_count, 2) + t[1].realize() + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 2) + self.assertListEqual(t.tolist(), [0, 5, 2, 3]) + +if __name__ == '__main__': + unittest.main()