more setitem kernel mem tests (#14749)

* more setitem kernel mem tests

test only the slice is accessed

* update
This commit is contained in:
chenyu
2026-02-14 11:01:03 -05:00
committed by GitHub
parent 446909fb7a
commit 8f6772fd8c
2 changed files with 110 additions and 69 deletions

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad import Tensor, TinyJit, Variable, dtypes, Device, GlobalCounters
from tinygrad import Tensor, TinyJit, Variable, dtypes, Device
from tinygrad.helpers import Context
import numpy as np
@@ -36,18 +36,6 @@ class TestSetitem(unittest.TestCase):
t[:3] *= 10
self.assertListEqual(t.tolist(), [0, 10, 20, 3, 4, 5, 6, 7, 8, 9])
def test_setitem_into_unrealized(self):
t = Tensor.arange(4).reshape(2, 2)
t[1] = 5
np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])
def test_setitem_into_unrealized_sliced_compute(self):
# base computation contains SHRINK from prior slicing (like QR decomposition pattern)
a = Tensor.arange(6, dtype=dtypes.float).reshape(2, 3)
w = a[0] + a[1] # unrealized ADD with SHRINK in graph: [3, 5, 7]
w[1] = 99
np.testing.assert_allclose(w.numpy(), [3, 99, 7])
def test_setitem_fancy_on_unrealized_view(self):
# fancy indexing setitem on unrealized SHRINK view (triggered infinite loop in graph_rewrite)
base = Tensor.arange(20, dtype=dtypes.float).reshape(4, 5)
@@ -69,62 +57,6 @@ class TestSetitem(unittest.TestCase):
t = Tensor.zeros(6, dtype=dtypes.float).contiguous().realize()
with self.assertRaises(RuntimeError): t[2:4] = Tensor([1, 2], dtype=dtypes.int)
def test_setitem_into_empty(self):
GlobalCounters.reset()
t = Tensor.empty(4)
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 1)
t[1].realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(t[1].item(), 5)
def test_setitem_into_tensor(self):
t = Tensor([1, 2, 3, 4]).realize()
GlobalCounters.reset()
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1].realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertListEqual(t.tolist(), [1, 5, 3, 4])
def test_setitem_into_cont(self):
t = Tensor.ones(4)
with self.assertRaises(RuntimeError): t[1] = 5
def test_setitem_into_const_alu(self):
# TODO: this is not consistent
GlobalCounters.reset()
t = Tensor.ones(4) + Tensor.ones(4)
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 2)
t[1].realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertListEqual(t.tolist(), [2, 5, 2, 2])
t = Tensor.ones(4) + Tensor.ones(4)
t.realize()
with self.assertRaises(RuntimeError): t[1] = 5
def test_setitem_into_arange(self):
# NOTE: arange has no real buffer, but assigning to it is fine
GlobalCounters.reset()
t = Tensor.arange(4)
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 2)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertListEqual(t.tolist(), [0, 5, 2, 3])
def test_setitem_chained_indexing(self):
# N[i][j] must work the same as N[i, j]
N1 = Tensor.zeros((3, 3)).contiguous().realize()

View File

@@ -0,0 +1,109 @@
import unittest
from tinygrad import Tensor, dtypes, GlobalCounters
class TestSetitemInto(unittest.TestCase):
def test_setitem_into_unrealized(self):
GlobalCounters.reset()
t = Tensor.arange(4, dtype=dtypes.int32).reshape(2, 2)
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(GlobalCounters.global_mem, 4*4+4*2)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertListEqual(t.tolist(), [[0, 1], [5, 5]])
def test_setitem_into_unrealized_sliced_compute(self):
# base computation contains SHRINK from prior slicing (like QR decomposition pattern)
GlobalCounters.reset()
a = Tensor.arange(8, dtype=dtypes.int32).reshape(2, 4)
w = a[0] + a[1] # unrealized ADD with SHRINK in graph: [4, 6, 8, 10]
self.assertEqual(GlobalCounters.kernel_count, 0)
w[1] = 99
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(GlobalCounters.global_mem, 4*4+4)
self.assertListEqual(w.tolist(), [4, 99, 8, 10])
def test_setitem_into_empty(self):
GlobalCounters.reset()
t = Tensor.empty(4, dtype=dtypes.int32)
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 4)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(t[1].item(), 5)
def test_setitem_into_empty_alu(self):
GlobalCounters.reset()
t = Tensor.empty(4, dtype=dtypes.int32) + 1
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(GlobalCounters.global_mem, 4*4*2+4)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(t[1].item(), 5)
def test_setitem_into_tensor(self):
t = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize()
GlobalCounters.reset()
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1].realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 4)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertListEqual(t.tolist(), [1, 5, 3, 4])
def test_setitem_into_tensor_alu(self):
t = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize() + 1
GlobalCounters.reset()
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(GlobalCounters.global_mem, 4*4*2+4)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertListEqual(t.tolist(), [2, 5, 4, 5])
def test_setitem_into_cont(self):
t = Tensor.ones(4, dtype=dtypes.int32)
with self.assertRaises(RuntimeError): t[1] = 5
def test_setitem_into_const_alu(self):
# TODO: this is not consistent
GlobalCounters.reset()
t = Tensor.ones(4, dtype=dtypes.int32) + 1
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertEqual(GlobalCounters.global_mem, 4*4+4)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertListEqual(t.tolist(), [2, 5, 2, 2])
t = Tensor.ones(4, dtype=dtypes.int32) + 1
t.realize()
with self.assertRaises(RuntimeError): t[1] = 5
def test_setitem_into_arange(self):
# NOTE: arange has no real buffer, but assigning to it is fine
GlobalCounters.reset()
t = Tensor.arange(4, dtype=dtypes.int32)
self.assertEqual(GlobalCounters.kernel_count, 0)
t[1] = 5
self.assertEqual(GlobalCounters.kernel_count, 2)
t[1].realize()
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertListEqual(t.tolist(), [0, 5, 2, 3])
if __name__ == '__main__':
unittest.main()