update test to reset and test kernel_count directly (#14832)

This commit is contained in:
chenyu
2026-02-17 11:48:46 -05:00
committed by GitHub
parent 9d4937ab5e
commit f147791105
3 changed files with 16 additions and 16 deletions

View File

@@ -32,11 +32,11 @@ class TestPickle(unittest.TestCase):
t_values = t.numpy()
del t # free buffers
print("** post pickle")
init = GlobalCounters.kernel_count
GlobalCounters.reset()
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t_values, t2.numpy())
# expect at most one COPY kernel
self.assertLessEqual(GlobalCounters.kernel_count-init, 1)
self.assertLessEqual(GlobalCounters.kernel_count, 1)
def test_pickle_realized_tensor_alt(self):
print("** init")

View File

@@ -6,9 +6,9 @@ class TestDataset(unittest.TestCase):
def test_dataset_is_realized(self):
X_train, _, _, _ = mnist()
X_train[0].contiguous().realize()
start = GlobalCounters.kernel_count
GlobalCounters.reset()
X_train[0].contiguous().realize()
self.assertEqual(GlobalCounters.kernel_count-start, 1)
self.assertEqual(GlobalCounters.kernel_count, 1)
if __name__ == '__main__':
unittest.main()

View File

@@ -263,16 +263,16 @@ class TestAssign(unittest.TestCase):
def test_assign_contiguous(self):
b = Tensor.arange(16).reshape(4,4).contiguous().realize()
a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1)
kc = GlobalCounters.kernel_count
GlobalCounters.reset()
b.assign(a.contiguous()).realize()
assert GlobalCounters.kernel_count - kc == 2
self.assertEqual(GlobalCounters.kernel_count, 2)
def test_assign_contiguous_permute(self):
b = Tensor.arange(16).reshape(4,4).contiguous().realize()
a = (Tensor.arange(16).reshape(4,4).contiguous().realize() + 1).permute((1,0))
kc = GlobalCounters.kernel_count
GlobalCounters.reset()
b.assign(a.contiguous()).realize()
assert GlobalCounters.kernel_count - kc == 2
self.assertEqual(GlobalCounters.kernel_count, 2)
def test_permuted_assignment(self):
a = Tensor(np.arange(N*N, dtype=np.float32)).reshape(N,N)
@@ -345,9 +345,9 @@ class TestAssign(unittest.TestCase):
c.assign(r + c)
d.assign(r + d)
kc = GlobalCounters.kernel_count
GlobalCounters.reset()
Tensor.realize(b, c, d)
assert GlobalCounters.kernel_count - kc == 1
self.assertEqual(GlobalCounters.kernel_count, 1)
np.testing.assert_allclose(b.numpy(), a.sum(1).numpy()+1)
np.testing.assert_allclose(c.numpy(), a.sum(1).numpy()+2)
np.testing.assert_allclose(d.numpy(), a.sum(1).numpy()+3)
@@ -389,13 +389,13 @@ class TestAssign(unittest.TestCase):
b = Tensor.arange(32 * 32).reshape(32, 32).realize()
c = Tensor.arange(32 * 32).reshape(32, 32).realize()
kc = GlobalCounters.kernel_count
GlobalCounters.reset()
r = a.sum(axis=1)
b_perm = b.permute(1, 0)
b.assign(r + b)
c.assign(r + b_perm.contiguous())
Tensor.realize(b, c)
assert GlobalCounters.kernel_count - kc == 2
self.assertEqual(GlobalCounters.kernel_count, 2)
np.testing.assert_equal(b.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32))
np.testing.assert_equal(c.numpy(), a.numpy().sum(1) + np.arange(32 * 32).reshape(32, 32).transpose(1, 0))
@@ -403,9 +403,9 @@ class TestAssign(unittest.TestCase):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2)
a.assign(a + b)
kc = GlobalCounters.kernel_count
GlobalCounters.reset()
a.realize()
assert GlobalCounters.kernel_count - kc == 1
self.assertEqual(GlobalCounters.kernel_count, 1)
np.testing.assert_equal(a.numpy(), np.ones((4, 4))+np.pad(np.ones((4, 4))[:, 0:2], ((0, 0), (0, 2)), constant_values=2))
def test_permuted_assignment_masked_view_not_contiguous(self):
@@ -442,9 +442,9 @@ class TestAssign(unittest.TestCase):
a = Tensor.arange(100).float().contiguous().realize()
expected = np.arange(100, dtype=np.float32)
expected[0:10] = expected[50:60].copy()
kc = GlobalCounters.kernel_count
GlobalCounters.reset()
a[0:10].assign(a[50:60]).realize()
assert GlobalCounters.kernel_count - kc == 2, "currently conservative, forces contiguous"
self.assertEqual(GlobalCounters.kernel_count, 2) # currently conservative, forces contiguous
np.testing.assert_allclose(a.numpy(), expected)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")