assert kernel count (#12205)

This commit is contained in:
qazal
2025-09-16 14:24:39 +03:00
committed by GitHub
parent e555748807
commit 122a50fe8c

View File

@@ -278,7 +278,7 @@ class TestSchedule(unittest.TestCase):
a = Tensor.empty(10,10,10)
b = Tensor.empty(10,10,1)
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
with self.assertRaises(KernelCountException): check_schedule(c, 1)
check_schedule(c, 2)
def test_allow_push_permutes(self):
a = Tensor.randn(10,10,10).realize()
@@ -316,7 +316,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty(10)
c = a+b
d = a.reshape(10,1)+b.reshape(10,1)
with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
check_schedule(d, 1, [c])
# failing in new lazy
def test_cache_binaryop_transpose(self):
@@ -324,7 +324,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty(10,10)
c = (a.T*b.T).T #.contiguous()
d = a*b
with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
check_schedule(d, 1, [c])
def test_cache_two_reduceops(self):
a = Tensor.empty(10)
@@ -558,7 +558,7 @@ class TestSchedule(unittest.TestCase):
c = a+b
d = a.reshape(10,1)+b.reshape(10,1)
out = c.sum() + d.sum()
with self.assertRaises(KernelCountException): check_schedule(out, 1)
check_schedule(out, 2)
def test_children_dont_push(self):
a = Tensor.empty(10, 10, 1)
@@ -569,6 +569,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(f, 2)
# failing in new lazy
@unittest.skip("always fusing elementwise")
def test_dont_fuse_binops_with_children(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@@ -576,8 +577,8 @@ class TestSchedule(unittest.TestCase):
keep_me = a+b
e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
d = keep_me+c
with self.assertRaises(KernelCountException): check_schedule(d, 2)
with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d])
check_schedule(d, 2)
check_schedule(keep_me, 0, [d])
#@unittest.skip("failing in old lazy")
def test_permute_breaks_fusion(self):
@@ -627,7 +628,8 @@ class TestSchedule(unittest.TestCase):
x = x.image_conv2d(w3, b3)
# NOOP, 3 convs, contiguous
with self.assertRaises(KernelCountException): check_schedule(x, 5)
#check_schedule(x, 5)
check_schedule(x, 8)
def test_image_conv_fusion_minimal(self):
b1 = Tensor.empty(16)