mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update test_uops_stats for setitem (#14710)
realize both full tensor and the slice should not add to global_mem
This commit is contained in:
@@ -72,6 +72,7 @@ class TestMemoryCount(unittest.TestCase):
|
||||
t = Tensor.empty(100, dtype=dtypes.int).realize()
|
||||
GlobalCounters.reset()
|
||||
t[20:50] = 3
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.global_mem, 30*4) # 30 elements written
|
||||
|
||||
def test_setitem_slice_tensor(self):
|
||||
@@ -79,12 +80,14 @@ class TestMemoryCount(unittest.TestCase):
|
||||
v = Tensor.empty(30, dtype=dtypes.int).realize()
|
||||
GlobalCounters.reset()
|
||||
t[20:50] = v
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.global_mem, 30*4*2) # 30 read + 30 written
|
||||
|
||||
def test_setitem_full(self):
|
||||
t = Tensor.empty(100, dtype=dtypes.int).realize()
|
||||
GlobalCounters.reset()
|
||||
t[:] = 3
|
||||
t.realize()
|
||||
self.assertEqual(GlobalCounters.global_mem, 100*4) # full buffer written
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device")
|
||||
|
||||
Reference in New Issue
Block a user