mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
clean up test_setitem_slice (#14750)
Moved the tests to test_setitem_schedule, and switched them to use contiguous zeros instead of Tensor.empty, because the scheduler now handles Tensor.empty differently.
This commit is contained in:
@@ -68,28 +68,6 @@ class TestMemoryCount(unittest.TestCase):
|
||||
_, mem = get_stats(a.assign(a+a))
|
||||
self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write
|
||||
|
||||
def test_setitem_slice_const(self):
|
||||
t = Tensor.empty(100, dtype=dtypes.int).realize()
|
||||
GlobalCounters.reset()
|
||||
t[20:50] = 3
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.global_mem, 30*4) # 30 elements written
|
||||
|
||||
def test_setitem_slice_tensor(self):
|
||||
t = Tensor.empty(100, dtype=dtypes.int).realize()
|
||||
v = Tensor.empty(30, dtype=dtypes.int).realize()
|
||||
GlobalCounters.reset()
|
||||
t[20:50] = v
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.global_mem, 30*4*2) # 30 read + 30 written
|
||||
|
||||
def test_setitem_full(self):
|
||||
t = Tensor.empty(100, dtype=dtypes.int).realize()
|
||||
GlobalCounters.reset()
|
||||
t[:] = 3
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.global_mem, 100*4) # full buffer written
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device")
|
||||
def test_copyout(self):
|
||||
a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU")
|
||||
|
||||
@@ -105,5 +105,33 @@ class TestSetitemInto(unittest.TestCase):
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
self.assertListEqual(t.tolist(), [0, 5, 2, 3])
|
||||
|
||||
def test_setitem_slice_const(self):
|
||||
t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
|
||||
GlobalCounters.reset()
|
||||
t[20:50] = 3
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 30*4) # 30 elements written
|
||||
|
||||
def test_setitem_slice_tensor(self):
|
||||
t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
|
||||
v = Tensor.zeros(30, dtype=dtypes.int32).contiguous().realize()
|
||||
GlobalCounters.reset()
|
||||
t[20:50] = v
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 30*4*2) # 30 read + 30 written
|
||||
|
||||
def test_setitem_full(self):
|
||||
t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
|
||||
GlobalCounters.reset()
|
||||
t[:] = 3
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertEqual(GlobalCounters.global_mem, 100*4) # full buffer written
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1306,7 +1306,7 @@ class Tensor(OpMixin):
|
||||
idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)
|
||||
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
|
||||
if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"): raise RuntimeError("advanced setitem is not supported for DISK tensors")
|
||||
if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors")
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
self.assign(self._getitem(indices, v))
|
||||
elif is_disk or self.uop.is_realized: # basic setitem, self is realized. TODO: disk uop.base is a COPY and not realized
|
||||
|
||||
Reference in New Issue
Block a user