diff --git a/test/backend/test_pickle.py b/test/backend/test_pickle.py index 02e9a7da63..dbf9dae402 100644 --- a/test/backend/test_pickle.py +++ b/test/backend/test_pickle.py @@ -32,11 +32,11 @@ class TestPickle(unittest.TestCase): t_values = t.numpy() del t # free buffers print("** post pickle") - init = GlobalCounters.kernel_count + GlobalCounters.reset() t2:Tensor = pickle.loads(st) np.testing.assert_equal(t_values, t2.numpy()) # expect at most one COPY kernel - self.assertLessEqual(GlobalCounters.kernel_count-init, 1) + self.assertLessEqual(GlobalCounters.kernel_count, 1) def test_pickle_realized_tensor_alt(self): print("** init") diff --git a/test/null/test_mnist_dataset.py b/test/null/test_mnist_dataset.py index 9db9a9e37d..88b44fca2c 100644 --- a/test/null/test_mnist_dataset.py +++ b/test/null/test_mnist_dataset.py @@ -6,9 +6,9 @@ class TestDataset(unittest.TestCase): def test_dataset_is_realized(self): X_train, _, _, _ = mnist() X_train[0].contiguous().realize() - start = GlobalCounters.kernel_count + GlobalCounters.reset() X_train[0].contiguous().realize() - self.assertEqual(GlobalCounters.kernel_count-start, 1) + self.assertEqual(GlobalCounters.kernel_count, 1) if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index ffcf970652..eb1f97a6cf 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -263,16 +263,16 @@ class TestAssign(unittest.TestCase): def test_assign_contiguous(self): b = Tensor.arange(16).reshape(4,4).contiguous().realize() a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1) - kc = GlobalCounters.kernel_count + GlobalCounters.reset() b.assign(a.contiguous()).realize() - assert GlobalCounters.kernel_count - kc == 2 + self.assertEqual(GlobalCounters.kernel_count, 2) def test_assign_contiguous_permute(self): b = Tensor.arange(16).reshape(4,4).contiguous().realize() a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1).permute((1,0)) - kc = GlobalCounters.kernel_count + GlobalCounters.reset() b.assign(a.contiguous()).realize() - assert GlobalCounters.kernel_count - kc == 2 + self.assertEqual(GlobalCounters.kernel_count, 2) def test_permuted_assignment(self): a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N) @@ -345,9 +345,9 @@ class TestAssign(unittest.TestCase): c.assign(r + c) d.assign(r + d) - kc = GlobalCounters.kernel_count + GlobalCounters.reset() Tensor.realize(b, c, d) - assert GlobalCounters.kernel_count - kc == 1 + self.assertEqual(GlobalCounters.kernel_count, 1) np.testing.assert_allclose(b.numpy(), a.sum(1).numpy()+1) np.testing.assert_allclose(c.numpy(), a.sum(1).numpy()+2) np.testing.assert_allclose(d.numpy(), a.sum(1).numpy()+3) @@ -389,13 +389,13 @@ class TestAssign(unittest.TestCase): b = Tensor.arange(32 * 32).reshape(32, 32).realize() c = Tensor.arange(32 * 32).reshape(32, 32).realize() - kc = GlobalCounters.kernel_count + GlobalCounters.reset() r = a.sum(axis=1) b_perm = b.permute(1, 0) b.assign(r + b) c.assign(r + b_perm.contiguous()) Tensor.realize(b, c) - assert GlobalCounters.kernel_count - kc == 2 + self.assertEqual(GlobalCounters.kernel_count, 2) np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32)) np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0)) @@ -403,9 +403,9 @@ class TestAssign(unittest.TestCase): a = Tensor.ones(4, 4).contiguous().realize() b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2) a.assign(a + b) - kc = GlobalCounters.kernel_count + GlobalCounters.reset() a.realize() - assert GlobalCounters.kernel_count - kc == 1 + self.assertEqual(GlobalCounters.kernel_count, 1) np.testing.assert_equal(a.numpy(), np.ones((4, 4))+np.pad(np.ones((4, 4))[:, 0:2], ((0, 0), (0, 2)), constant_values=2)) def test_permuted_assignment_masked_view_not_contiguous(self): @@ -442,9 +442,9 @@ class TestAssign(unittest.TestCase): a = Tensor.arange(100).float().contiguous().realize() expected = np.arange(100, dtype=np.float32) expected[0:10] = expected[50:60].copy() - kc = GlobalCounters.kernel_count + GlobalCounters.reset() a[0:10].assign(a[50:60]).realize() - assert GlobalCounters.kernel_count - kc == 2, "currently conservative, forces contiguous" + self.assertEqual(GlobalCounters.kernel_count, 2) # currently conservative, forces contiguous np.testing.assert_allclose(a.numpy(), expected) @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")