mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update test to reset and test kernel_count directly (#14832)
This commit is contained in:
@@ -32,11 +32,11 @@ class TestPickle(unittest.TestCase):
|
||||
t_values = t.numpy()
|
||||
del t # free buffers
|
||||
print("** post pickle")
|
||||
init = GlobalCounters.kernel_count
|
||||
GlobalCounters.reset()
|
||||
t2:Tensor = pickle.loads(st)
|
||||
np.testing.assert_equal(t_values, t2.numpy())
|
||||
# expect at most one COPY kernel
|
||||
self.assertLessEqual(GlobalCounters.kernel_count-init, 1)
|
||||
self.assertLessEqual(GlobalCounters.kernel_count, 1)
|
||||
|
||||
def test_pickle_realized_tensor_alt(self):
|
||||
print("** init")
|
||||
|
||||
@@ -6,9 +6,9 @@ class TestDataset(unittest.TestCase):
|
||||
def test_dataset_is_realized(self):
|
||||
X_train, _, _, _ = mnist()
|
||||
X_train[0].contiguous().realize()
|
||||
start = GlobalCounters.kernel_count
|
||||
GlobalCounters.reset()
|
||||
X_train[0].contiguous().realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count-start, 1)
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -263,16 +263,16 @@ class TestAssign(unittest.TestCase):
|
||||
def test_assign_contiguous(self):
|
||||
b = Tensor.arange(16).reshape(4,4).contiguous().realize()
|
||||
a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)
|
||||
kc = GlobalCounters.kernel_count
|
||||
GlobalCounters.reset()
|
||||
b.assign(a.contiguous()).realize()
|
||||
assert GlobalCounters.kernel_count - kc == 2
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
|
||||
def test_assign_contiguous_permute(self):
|
||||
b = Tensor.arange(16).reshape(4,4).contiguous().realize()
|
||||
a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1).permute((1,0))
|
||||
kc = GlobalCounters.kernel_count
|
||||
GlobalCounters.reset()
|
||||
b.assign(a.contiguous()).realize()
|
||||
assert GlobalCounters.kernel_count - kc == 2
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
|
||||
def test_permuted_assignment(self):
|
||||
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
|
||||
@@ -345,9 +345,9 @@ class TestAssign(unittest.TestCase):
|
||||
c.assign(r + c)
|
||||
d.assign(r + d)
|
||||
|
||||
kc = GlobalCounters.kernel_count
|
||||
GlobalCounters.reset()
|
||||
Tensor.realize(b, c, d)
|
||||
assert GlobalCounters.kernel_count - kc == 1
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
np.testing.assert_allclose(b.numpy(), a.sum(1).numpy()+1)
|
||||
np.testing.assert_allclose(c.numpy(), a.sum(1).numpy()+2)
|
||||
np.testing.assert_allclose(d.numpy(), a.sum(1).numpy()+3)
|
||||
@@ -389,13 +389,13 @@ class TestAssign(unittest.TestCase):
|
||||
b = Tensor.arange(32 * 32).reshape(32, 32).realize()
|
||||
c = Tensor.arange(32 * 32).reshape(32, 32).realize()
|
||||
|
||||
kc = GlobalCounters.kernel_count
|
||||
GlobalCounters.reset()
|
||||
r = a.sum(axis=1)
|
||||
b_perm = b.permute(1, 0)
|
||||
b.assign(r + b)
|
||||
c.assign(r + b_perm.contiguous())
|
||||
Tensor.realize(b, c)
|
||||
assert GlobalCounters.kernel_count - kc == 2
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2)
|
||||
np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
|
||||
np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))
|
||||
|
||||
@@ -403,9 +403,9 @@ class TestAssign(unittest.TestCase):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2)
|
||||
a.assign(a + b)
|
||||
kc = GlobalCounters.kernel_count
|
||||
GlobalCounters.reset()
|
||||
a.realize()
|
||||
assert GlobalCounters.kernel_count - kc == 1
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
np.testing.assert_equal(a.numpy(), np.ones((4, 4))+np.pad(np.ones((4, 4))[:, 0:2], ((0, 0), (0, 2)), constant_values=2))
|
||||
|
||||
def test_permuted_assignment_masked_view_not_contiguous(self):
|
||||
@@ -442,9 +442,9 @@ class TestAssign(unittest.TestCase):
|
||||
a = Tensor.arange(100).float().contiguous().realize()
|
||||
expected = np.arange(100, dtype=np.float32)
|
||||
expected[0:10] = expected[50:60].copy()
|
||||
kc = GlobalCounters.kernel_count
|
||||
GlobalCounters.reset()
|
||||
a[0:10].assign(a[50:60]).realize()
|
||||
assert GlobalCounters.kernel_count - kc == 2, "currently conservative, forces contiguous"
|
||||
self.assertEqual(GlobalCounters.kernel_count, 2) # currently conservative, forces contiguous
|
||||
np.testing.assert_allclose(a.numpy(), expected)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
|
||||
Reference in New Issue
Block a user