Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-26 15:28:10 -05:00
assert kernel counts in unsupported fusions (#4643)
* replace with comments
* not relevant
* update comment
* custom exception maybe
* fix LoadOps.VIEW
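In brief, the pattern this commit adopts: instead of hiding known-unsupported fusions behind @unittest.skip, check_schedule raises a dedicated KernelCountException and the affected tests wrap the call in assertRaises, so they keep executing and pin down the current kernel count. A minimal, self-contained sketch of that pattern follows; the toy check_schedule below takes an already-built kernel list for illustration, whereas the real helper in test/test_schedule.py builds a schedule from tensors.

import unittest

class KernelCountException(Exception): pass

# Toy stand-in for the test helper: raise a dedicated exception instead of
# a bare assert, so callers can opt in to "this fusion is known-unsupported"
# with assertRaises while still exercising the scheduler.
def check_schedule(sched: list, allowed: int):
  if len(sched) != allowed: raise KernelCountException()

class TestKernelCount(unittest.TestCase):
  def test_unsupported_fusion_still_checked(self):
    # Previously @unittest.skip would hide this case entirely; now the test
    # runs and fails loudly if the scheduler ever starts (or stops) fusing.
    with self.assertRaises(KernelCountException):
      check_schedule(["kernel1", "kernel2"], 1)  # two kernels scheduled, one allowed

if __name__ == "__main__":
  unittest.main()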
@@ -5,16 +5,17 @@
 import unittest
 import numpy as np
 from typing import List, Optional, Union
-from tinygrad.engine.realize import run_schedule
-from tinygrad import nn, dtypes
 from tinygrad.tensor import Tensor
 from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
 from tinygrad.helpers import DEBUG, flatten
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.engine.graph import print_tree
 from tinygrad.engine.schedule import create_schedule
+from tinygrad import nn, dtypes
+from tinygrad.engine.realize import run_schedule
 from test.helpers import is_dtype_supported
 
+class KernelCountException(Exception): pass
 def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
   if isinstance(t, Tensor): t = [t]
   seen = set()
@@ -30,7 +31,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
     for i, s in enumerate(sched):
       print("kernel", i+1)
       for op in s.ast: print_tree(op)
-  assert len(sched) == allowed, f"{len(sched)} != {allowed}"
+  if len(sched) != allowed: raise KernelCountException()
   # test the (non loadops) ops linearize
   for s in sched:
     if s.ast[0].op in LoadOps: continue
@@ -102,12 +103,12 @@ class TestSchedule(unittest.TestCase):
     c = a.sum(axis=0) + b
     check_schedule(c, 1)
 
-  @unittest.skip("not pushing permutes through reduces")
+  # not pushing permutes through reduces
   def test_reduce_permute_binop_fusion(self):
     a = Tensor.empty(10,10,10)
     b = Tensor.empty(10,10,1)
     c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
-    check_schedule(c, 1)
+    with self.assertRaises(KernelCountException): check_schedule(c, 1)
 
   def test_binop_early_reshape_reduce_fusion(self):
     a = Tensor.empty(100)
@@ -132,21 +133,21 @@ class TestSchedule(unittest.TestCase):
     d = a+b
     check_schedule(d, 0, [c])
 
-  @unittest.skip("failing in old lazy")
+  # failing in new lazy
   def test_cache_binaryop_reshaped(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
     c = a+b
     d = a.reshape(10,1)+b.reshape(10,1)
-    check_schedule(d, 0, [c])
+    with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
 
-  @unittest.skip("failing in new lazy")
+  # failing in new lazy
   def test_cache_binaryop_transpose(self):
     a = Tensor.empty(10,10)
     b = Tensor.empty(10,10)
     c = (a.T*b.T).T #.contiguous()
     d = a*b
-    check_schedule(d, 0, [c])
+    with self.assertRaises(KernelCountException): check_schedule(d, 0, [c])
 
   def test_cache_two_reduceops(self):
     a = Tensor.empty(10)
@@ -310,17 +311,14 @@ class TestSchedule(unittest.TestCase):
     c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b
     check_schedule(c, 1)
 
-  @unittest.skip("failing in old lazy")
   def test_fancy_reshape_fusion(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
     c = a+b
     d = a.reshape(10,1)+b.reshape(10,1)
     out = c.sum() + d.sum()
-    check_schedule(out, 1)
+    with self.assertRaises(KernelCountException): check_schedule(out, 1)
 
-  # NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
   @unittest.skip("not real world")
   def test_children_dont_push(self):
     a = Tensor.empty(10, 10, 1)
     b = Tensor.empty(10, 10, 1)
@@ -329,7 +327,7 @@ class TestSchedule(unittest.TestCase):
     f = d+e
     check_schedule(f, 2)
 
-  @unittest.skip("failing in new lazy")
+  # failing in new lazy
   def test_dont_fuse_binops_with_children(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
@@ -337,8 +335,8 @@ class TestSchedule(unittest.TestCase):
     keep_me = a+b
     e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
     d = keep_me+c
-    check_schedule(d, 2)
-    check_schedule(keep_me, 0, [d])
+    with self.assertRaises(KernelCountException): check_schedule(d, 2)
+    with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d])
 
   #@unittest.skip("failing in old lazy")
   def test_permute_breaks_fusion(self):
@@ -374,9 +372,7 @@ class TestSchedule(unittest.TestCase):
     check_schedule(e, 2)
 
   # this is the failing case in openpilot...it's very simple like this
-  @unittest.skip("failing in old lazy")
   def test_image_conv_fusion(self):
-    from tinygrad.features.image import image_conv2d
     w1 = Tensor.empty(16, 16, 1, 1)
     b1 = Tensor.empty(16)
     w2 = Tensor.empty(16, 16, 1, 1)
@@ -385,12 +381,12 @@ class TestSchedule(unittest.TestCase):
     b3 = Tensor.empty(16)
 
     x = Tensor.empty(1, 16, 32, 32)
-    x = base = image_conv2d(x, w1, b1)
-    x = image_conv2d(x, w2, b2) + base
-    x = image_conv2d(x, w3, b3)
+    x = base = x.image_conv2d(w1, b1)
+    x = x.image_conv2d(w2, b2) + base
+    x = x.image_conv2d(w3, b3)
 
     # NOOP, 3 convs, contiguous
-    check_schedule(x, 5)
+    with self.assertRaises(KernelCountException): check_schedule(x, 5)
 
   def test_image_conv_fusion_minimal(self):
     b1 = Tensor.empty(16)
@@ -522,12 +518,12 @@ class TestSchedule(unittest.TestCase):
     out = x + y
     check_schedule(out, 2) # TODO: this should be 1
 
-  @unittest.skip("broken due to const folding and two contiguous are different kernels")
+  # broken due to const folding and two contiguous are different kernels
  def test_const_no_recompute(self):
     x = Tensor(2) + Tensor(2)
     y = Tensor(2) + Tensor(2)
     out = x.contiguous() + y.contiguous()
-    check_schedule(out, 2)
+    with self.assertRaises(KernelCountException): check_schedule(out, 2, filter_loadops=False)
 
   def test_reduce_same_size(self):
     a = Tensor.empty(4, 4)
@@ -805,7 +801,7 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 1))
     np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
 
-  def test_pad_reduce_usafe(self):
+  def test_pad_reduce_unsafe(self):
     Tensor.manual_seed(0)
     a = Tensor.rand(3, 4, 5).realize()
     out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()