assert kernel counts in unsupported fusions (#4643)

* replace with comments

* not relevant

* update comment

* custom exception maybe

* fix LoadOps.VIEW
This commit is contained in:
qazal
2024-05-19 01:14:37 +08:00
committed by GitHub
parent a5204fe89d
commit bf8f855838

View File

@@ -5,16 +5,17 @@
import unittest
import numpy as np
from typing import List, Optional, Union
from tinygrad.engine.realize import run_schedule
from tinygrad import nn, dtypes
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, flatten
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.engine.graph import print_tree
from tinygrad.engine.schedule import create_schedule
from tinygrad import nn, dtypes
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported
class KernelCountException(Exception): pass
def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
if isinstance(t, Tensor): t = [t]
seen = set()
@@ -30,7 +31,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
for i, s in enumerate(sched):
print("kernel", i+1)
for op in s.ast: print_tree(op)
assert len(sched) == allowed, f"{len(sched)} != {allowed}"
if len(sched) != allowed: raise KernelCountException()
# test the (non loadops) ops linearize
for s in sched:
if s.ast[0].op in LoadOps: continue
@@ -102,12 +103,12 @@ class TestSchedule(unittest.TestCase):
c = a.sum(axis=0) + b
check_schedule(c, 1)
@unittest.skip("not pushing permutes through reduces")
# not pushing permutes through reduces
def test_reduce_permute_binop_fusion(self):
a = Tensor.empty(10,10,10)
b = Tensor.empty(10,10,1)
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
check_schedule(c, 1)
with self.assertRaises(KernelCountException): check_schedule(c, 1)
def test_binop_early_reshape_reduce_fusion(self):
a = Tensor.empty(100)
@@ -132,21 +133,21 @@ class TestSchedule(unittest.TestCase):
d = a+b
check_schedule(d, 0, [c])
@unittest.skip("failing in old lazy")
# failing in new lazy
def test_cache_binaryop_reshaped(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = a+b
d = a.reshape(10,1)+b.reshape(10,1)
check_schedule(d, 0, [c])
with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
@unittest.skip("failing in new lazy")
# failing in new lazy
def test_cache_binaryop_transpose(self):
a = Tensor.empty(10,10)
b = Tensor.empty(10,10)
c = (a.T*b.T).T #.contiguous()
d = a*b
check_schedule(d, 0, [c])
with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
def test_cache_two_reduceops(self):
a = Tensor.empty(10)
@@ -310,17 +311,14 @@ class TestSchedule(unittest.TestCase):
c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b
check_schedule(c, 1)
@unittest.skip("failing in old lazy")
def test_fancy_reshape_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
c = a+b
d = a.reshape(10,1)+b.reshape(10,1)
out = c.sum() + d.sum()
check_schedule(out, 1)
with self.assertRaises(KernelCountException): check_schedule(out, 1)
# NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
@unittest.skip("not real world")
def test_children_dont_push(self):
a = Tensor.empty(10, 10, 1)
b = Tensor.empty(10, 10, 1)
@@ -329,7 +327,7 @@ class TestSchedule(unittest.TestCase):
f = d+e
check_schedule(f, 2)
@unittest.skip("failing in new lazy")
# failing in new lazy
def test_dont_fuse_binops_with_children(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@@ -337,8 +335,8 @@ class TestSchedule(unittest.TestCase):
keep_me = a+b
e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
d = keep_me+c
check_schedule(d, 2)
check_schedule(keep_me, 0, [d])
with self.assertRaises(KernelCountException): check_schedule(d, 2)
with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d])
#@unittest.skip("failing in old lazy")
def test_permute_breaks_fusion(self):
@@ -374,9 +372,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(e, 2)
# this is the failing case in openpilot...it's very simple like this
@unittest.skip("failing in old lazy")
def test_image_conv_fusion(self):
from tinygrad.features.image import image_conv2d
w1 = Tensor.empty(16, 16, 1, 1)
b1 = Tensor.empty(16)
w2 = Tensor.empty(16, 16, 1, 1)
@@ -385,12 +381,12 @@ class TestSchedule(unittest.TestCase):
b3 = Tensor.empty(16)
x = Tensor.empty(1, 16, 32, 32)
x = base = image_conv2d(x, w1, b1)
x = image_conv2d(x, w2, b2) + base
x = image_conv2d(x, w3, b3)
x = base = x.image_conv2d(w1, b1)
x = x.image_conv2d(w2, b2) + base
x = x.image_conv2d(w3, b3)
# NOOP, 3 convs, contiguous
check_schedule(x, 5)
with self.assertRaises(KernelCountException): check_schedule(x, 5)
def test_image_conv_fusion_minimal(self):
b1 = Tensor.empty(16)
@@ -522,12 +518,12 @@ class TestSchedule(unittest.TestCase):
out = x + y
check_schedule(out, 2) # TODO: this should be 1
@unittest.skip("broken due to const folding and two contiguous are different kernels")
# broken due to const folding and two contiguous are different kernels
def test_const_no_recompute(self):
x = Tensor(2) + Tensor(2)
y = Tensor(2) + Tensor(2)
out = x.contiguous() + y.contiguous()
check_schedule(out, 2)
with self.assertRaises(KernelCountException): check_schedule(out, 2, filter_loadops=False)
def test_reduce_same_size(self):
a = Tensor.empty(4, 4)
@@ -805,7 +801,7 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
def test_pad_reduce_usafe(self):
def test_pad_reduce_unsafe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()