mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
assert kernel count (#12205)
This commit is contained in:
@@ -278,7 +278,7 @@ class TestSchedule(unittest.TestCase):
|
||||
a = Tensor.empty(10,10,10)
|
||||
b = Tensor.empty(10,10,1)
|
||||
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
|
||||
with self.assertRaises(KernelCountException): check_schedule(c, 1)
|
||||
check_schedule(c, 2)
|
||||
|
||||
def test_allow_push_permutes(self):
|
||||
a = Tensor.randn(10,10,10).realize()
|
||||
@@ -316,7 +316,7 @@ class TestSchedule(unittest.TestCase):
|
||||
b = Tensor.empty(10)
|
||||
c = a+b
|
||||
d = a.reshape(10,1)+b.reshape(10,1)
|
||||
with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
|
||||
check_schedule(d, 1, [c])
|
||||
|
||||
# failing in new lazy
|
||||
def test_cache_binaryop_transpose(self):
|
||||
@@ -324,7 +324,7 @@ class TestSchedule(unittest.TestCase):
|
||||
b = Tensor.empty(10,10)
|
||||
c = (a.T*b.T).T #.contiguous()
|
||||
d = a*b
|
||||
with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
|
||||
check_schedule(d, 1, [c])
|
||||
|
||||
def test_cache_two_reduceops(self):
|
||||
a = Tensor.empty(10)
|
||||
@@ -558,7 +558,7 @@ class TestSchedule(unittest.TestCase):
|
||||
c = a+b
|
||||
d = a.reshape(10,1)+b.reshape(10,1)
|
||||
out = c.sum() + d.sum()
|
||||
with self.assertRaises(KernelCountException): check_schedule(out, 1)
|
||||
check_schedule(out, 2)
|
||||
|
||||
def test_children_dont_push(self):
|
||||
a = Tensor.empty(10, 10, 1)
|
||||
@@ -569,6 +569,7 @@ class TestSchedule(unittest.TestCase):
|
||||
check_schedule(f, 2)
|
||||
|
||||
# failing in new lazy
|
||||
@unittest.skip("always fusing elementwise")
|
||||
def test_dont_fuse_binops_with_children(self):
|
||||
a = Tensor.empty(10)
|
||||
b = Tensor.empty(10)
|
||||
@@ -576,8 +577,8 @@ class TestSchedule(unittest.TestCase):
|
||||
keep_me = a+b
|
||||
e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
|
||||
d = keep_me+c
|
||||
with self.assertRaises(KernelCountException): check_schedule(d, 2)
|
||||
with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d])
|
||||
check_schedule(d, 2)
|
||||
check_schedule(keep_me, 0, [d])
|
||||
|
||||
#@unittest.skip("failing in old lazy")
|
||||
def test_permute_breaks_fusion(self):
|
||||
@@ -627,7 +628,8 @@ class TestSchedule(unittest.TestCase):
|
||||
x = x.image_conv2d(w3, b3)
|
||||
|
||||
# NOOP, 3 convs, contiguous
|
||||
with self.assertRaises(KernelCountException): check_schedule(x, 5)
|
||||
#check_schedule(x, 5)
|
||||
check_schedule(x, 8)
|
||||
|
||||
def test_image_conv_fusion_minimal(self):
|
||||
b1 = Tensor.empty(16)
|
||||
|
||||
Reference in New Issue
Block a user