mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Undo the view and make it a mask; this also fuses the setitem with any pending compute. One behavior change: for a target not backed by a buffer (const and arange), rangeify makes the output contiguous under the hood. This is strictly better than raising and asking the user to call contiguous, as that would no longer be fusable.
151 lines
5.5 KiB
Python
151 lines
5.5 KiB
Python
import unittest
|
|
from tinygrad import Tensor, dtypes, GlobalCounters
|
|
|
|
class TestSetitemInto(unittest.TestCase):
  """Kernel-count and memory-traffic checks for setitem on different kinds of target tensors.

  Every test resets ``GlobalCounters`` and then inspects ``kernel_count`` and
  ``global_mem`` around the setitem and the ``realize`` calls that follow it.
  """

  # size in bytes of one int32 element; expected global_mem values are multiples of it
  _INT32_BYTES = 4

  def _expect_kernels(self, count):
    """Assert that ``GlobalCounters.kernel_count`` currently equals ``count``."""
    self.assertEqual(GlobalCounters.kernel_count, count)

  def _expect_global_mem(self, nbytes):
    """Assert that ``GlobalCounters.global_mem`` currently equals ``nbytes``."""
    self.assertEqual(GlobalCounters.global_mem, nbytes)

  def test_setitem_into_unrealized(self):
    GlobalCounters.reset()
    grid = Tensor.arange(4, dtype=dtypes.int32).reshape(2, 2)
    self._expect_kernels(0)
    grid[1] = 5
    # the assignment by itself must not run a kernel
    self._expect_kernels(0)
    grid.realize()
    self._expect_kernels(1)
    self._expect_global_mem(4 * self._INT32_BYTES)
    # realizing the assigned row and the whole tensor again must not add kernels
    grid[1].realize()
    grid.realize()
    self._expect_kernels(1)
    self.assertListEqual(grid.tolist(), [[0, 1], [5, 5]])

  def test_setitem_into_unrealized_sliced_compute(self):
    # the target's pending computation already contains SHRINK from earlier slicing
    # (the same pattern shows up in QR decomposition)
    GlobalCounters.reset()
    base = Tensor.arange(8, dtype=dtypes.int32).reshape(2, 4)
    row_sum = base[0] + base[1]  # unrealized ADD with SHRINK in graph: [4, 6, 8, 10]
    self._expect_kernels(0)
    row_sum[1] = 99
    self._expect_kernels(0)
    row_sum.realize()
    self._expect_kernels(1)
    self._expect_global_mem(4 * self._INT32_BYTES)
    self.assertListEqual(row_sum.tolist(), [4, 99, 8, 10])

  def test_setitem_into_empty(self):
    GlobalCounters.reset()
    buf = Tensor.empty(4, dtype=dtypes.int32)
    buf[1] = 5
    self._expect_kernels(0)
    buf.realize()
    self._expect_kernels(1)
    # TODO: this can be just 4 if empty goes through is_realized setitem path
    # NOTE(review): presumably 3 elements are read and written back (2 accesses each)
    # and 1 element is assigned directly -- confirm
    self._expect_global_mem((3 * 2 + 1) * self._INT32_BYTES)
    buf[1].realize()
    buf.realize()
    self._expect_kernels(1)
    self.assertEqual(buf[1].item(), 5)

  def test_setitem_into_empty_alu(self):
    GlobalCounters.reset()
    shifted = Tensor.empty(4, dtype=dtypes.int32) + 1
    self._expect_kernels(0)
    shifted[1] = 5
    self._expect_kernels(0)
    shifted.realize()
    self._expect_kernels(1)
    # 3 elements had +1, 1 is assigned directly
    self._expect_global_mem((3 * 2 + 1) * self._INT32_BYTES)
    shifted[1].realize()
    shifted.realize()
    self._expect_kernels(1)
    self.assertEqual(shifted[1].item(), 5)

  def test_setitem_into_tensor(self):
    # the target is realized before the counters are reset
    data = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize()
    GlobalCounters.reset()
    data[1] = 5
    self._expect_kernels(0)
    data[1].realize()
    self._expect_kernels(1)
    # expected traffic is the size of a single int32 element
    self._expect_global_mem(1 * self._INT32_BYTES)
    data.realize()
    self._expect_kernels(1)
    self.assertListEqual(data.tolist(), [1, 5, 3, 4])

  def test_setitem_into_tensor_alu(self):
    # the target is a pending +1 on top of an already realized tensor
    data = Tensor([1, 2, 3, 4], dtype=dtypes.int32).realize() + 1
    GlobalCounters.reset()
    data[1] = 5
    self._expect_kernels(0)
    data[1].realize()
    self._expect_kernels(1)
    # 3 elements had +1, 1 is assigned directly
    self._expect_global_mem((3 * 2 + 1) * self._INT32_BYTES)
    data[1].realize()
    data.realize()
    self._expect_kernels(1)
    self.assertListEqual(data.tolist(), [2, 5, 4, 5])

  def test_setitem_into_cont(self):
    GlobalCounters.reset()
    ones = Tensor.ones(4, dtype=dtypes.int32)
    ones[1] = 5
    self._expect_kernels(0)
    ones.realize()
    self._expect_kernels(1)
    self._expect_global_mem(4 * self._INT32_BYTES)
    ones[1].realize()
    ones.realize()
    self._expect_kernels(1)
    self.assertListEqual(ones.tolist(), [1, 5, 1, 1])

  def test_setitem_into_const_alu(self):
    GlobalCounters.reset()
    twos = Tensor.ones(4, dtype=dtypes.int32) + 1
    twos[1] = 5
    self._expect_kernels(0)
    twos.realize()
    self._expect_kernels(1)
    self._expect_global_mem(4 * self._INT32_BYTES)
    twos[1].realize()
    twos.realize()
    self._expect_kernels(1)
    self.assertListEqual(twos.tolist(), [2, 5, 2, 2])

  def test_setitem_into_arange(self):
    # NOTE: arange has no real buffer, but assigning to it is fine
    GlobalCounters.reset()
    seq = Tensor.arange(4, dtype=dtypes.int32)
    seq[1] = 5
    self._expect_kernels(0)
    seq.realize()
    self._expect_kernels(1)
    self.assertListEqual(seq.tolist(), [0, 5, 2, 3])

  def test_setitem_slice_const(self):
    dest = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
    GlobalCounters.reset()
    dest[20:50] = 3
    self._expect_kernels(0)
    dest.realize()
    self._expect_kernels(1)
    # 30 elements written
    self._expect_global_mem(30 * self._INT32_BYTES)

  def test_setitem_slice_tensor(self):
    dest = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
    values = Tensor.zeros(30, dtype=dtypes.int32).contiguous().realize()
    GlobalCounters.reset()
    dest[20:50] = values
    self._expect_kernels(0)
    dest.realize()
    self._expect_kernels(1)
    # 30 read + 30 written
    self._expect_global_mem(30 * self._INT32_BYTES * 2)

  def test_setitem_full(self):
    dest = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
    GlobalCounters.reset()
    dest[:] = 3
    self._expect_kernels(0)
    dest.realize()
    self._expect_kernels(1)
    # full buffer written
    self._expect_global_mem(100 * self._INT32_BYTES)
|
|
|
|
# run the whole suite through unittest's CLI runner when this file is executed as a script
if __name__ == '__main__':
  unittest.main()
|