From bf8f85583823e2426b2db3e7a4e8a80dd386712b Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 19 May 2024 01:14:37 +0800 Subject: [PATCH] assert kernel counts in unsupported fusions (#4643) * replace with comments * not relevant * update comment * custom exception maybe * fix LoadOps.VIEW --- test/test_schedule.py | 46 ++++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index a11d05df55..b11ced5935 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -5,16 +5,17 @@ import unittest import numpy as np from typing import List, Optional, Union -from tinygrad.engine.realize import run_schedule +from tinygrad import nn, dtypes from tinygrad.tensor import Tensor from tinygrad.ops import BinaryOps, LoadOps, ReduceOps from tinygrad.helpers import DEBUG, flatten from tinygrad.codegen.linearizer import Linearizer from tinygrad.engine.graph import print_tree from tinygrad.engine.schedule import create_schedule -from tinygrad import nn, dtypes +from tinygrad.engine.realize import run_schedule from test.helpers import is_dtype_supported +class KernelCountException(Exception): pass def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True): if isinstance(t, Tensor): t = [t] seen = set() @@ -30,7 +31,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt for i, s in enumerate(sched): print("kernel", i+1) for op in s.ast: print_tree(op) - assert len(sched) == allowed, f"{len(sched)} != {allowed}" + if len(sched) != allowed: raise KernelCountException() # test the (non loadops) ops linearize for s in sched: if s.ast[0].op in LoadOps: continue @@ -102,12 +103,12 @@ class TestSchedule(unittest.TestCase): c = a.sum(axis=0) + b check_schedule(c, 1) - @unittest.skip("not pushing permutes through reduces") + # not pushing permutes through reduces def test_reduce_permute_binop_fusion(self): a = Tensor.empty(10,10,10) b = Tensor.empty(10,10,1) c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b - check_schedule(c, 1) + with self.assertRaises(KernelCountException): check_schedule(c, 1) def test_binop_early_reshape_reduce_fusion(self): a = Tensor.empty(100) @@ -132,21 +133,21 @@ class TestSchedule(unittest.TestCase): d = a+b check_schedule(d, 0, [c]) - @unittest.skip("failing in old lazy") + # failing in new lazy def test_cache_binaryop_reshaped(self): a = Tensor.empty(10) b = Tensor.empty(10) c = a+b d = a.reshape(10,1)+b.reshape(10,1) - check_schedule(d, 0, [c]) + with self.assertRaises(KernelCountException): check_schedule(d, 0, [c]) - @unittest.skip("failing in new lazy") + # failing in new lazy def test_cache_binaryop_transpose(self): a = Tensor.empty(10,10) b = Tensor.empty(10,10) c = (a.T*b.T).T #.contiguous() d = a*b - check_schedule(d, 0, [c]) + with self.assertRaises(KernelCountException): check_schedule(d, 0, [c]) def test_cache_two_reduceops(self): a = Tensor.empty(10) @@ -310,17 +311,14 @@ class TestSchedule(unittest.TestCase): c = a.sum((0,1)).cast(dtypes.float16).permute(1,0).reshape(4,4,1).permute(1,0,2).reshape(16) + b check_schedule(c, 1) - @unittest.skip("failing in old lazy") def test_fancy_reshape_fusion(self): a = Tensor.empty(10) b = Tensor.empty(10) c = a+b d = a.reshape(10,1)+b.reshape(10,1) out = c.sum() + d.sum() - check_schedule(out, 1) + with self.assertRaises(KernelCountException): check_schedule(out, 1) - # NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first - @unittest.skip("not real world") def test_children_dont_push(self): a = Tensor.empty(10, 10, 1) b = Tensor.empty(10, 10, 1) @@ -329,7 +327,7 @@ class TestSchedule(unittest.TestCase): f = d+e check_schedule(f, 2) - @unittest.skip("failing in new lazy") + # failing in new lazy def test_dont_fuse_binops_with_children(self): a = Tensor.empty(10) b = Tensor.empty(10) @@ -337,8 +335,8 @@ class TestSchedule(unittest.TestCase): keep_me = a+b e = keep_me.sum() # noqa: F841 give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse) d = keep_me+c - check_schedule(d, 2) - check_schedule(keep_me, 0, [d]) + with self.assertRaises(KernelCountException): check_schedule(d, 2) + with self.assertRaises(KernelCountException): check_schedule(keep_me, 0, [d]) #@unittest.skip("failing in old lazy") def test_permute_breaks_fusion(self): @@ -374,9 +372,7 @@ class TestSchedule(unittest.TestCase): check_schedule(e, 2) # this is the failing case in openpilot...it's very simple like this - @unittest.skip("failing in old lazy") def test_image_conv_fusion(self): - from tinygrad.features.image import image_conv2d w1 = Tensor.empty(16, 16, 1, 1) b1 = Tensor.empty(16) w2 = Tensor.empty(16, 16, 1, 1) @@ -385,12 +381,12 @@ class TestSchedule(unittest.TestCase): b3 = Tensor.empty(16) x = Tensor.empty(1, 16, 32, 32) - x = base = image_conv2d(x, w1, b1) - x = image_conv2d(x, w2, b2) + base - x = image_conv2d(x, w3, b3) + x = base = x.image_conv2d(w1, b1) + x = x.image_conv2d(w2, b2) + base + x = x.image_conv2d(w3, b3) # NOOP, 3 convs, contiguous - check_schedule(x, 5) + with self.assertRaises(KernelCountException): check_schedule(x, 5) def test_image_conv_fusion_minimal(self): b1 = Tensor.empty(16) @@ -522,12 +518,12 @@ class TestSchedule(unittest.TestCase): out = x + y check_schedule(out, 2) # TODO: this should be 1 - @unittest.skip("broken due to const folding and two contiguous are different kernels") + # broken due to const folding and two contiguous are different kernels def test_const_no_recompute(self): x = Tensor(2) + Tensor(2) y = Tensor(2) + Tensor(2) out = x.contiguous() + y.contiguous() - check_schedule(out, 2) + with self.assertRaises(KernelCountException): check_schedule(out, 2, filter_loadops=False) def test_reduce_same_size(self): a = Tensor.empty(4, 4) @@ -805,7 +801,7 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 1)) np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum()) - def test_pad_reduce_usafe(self): + def test_pad_reduce_unsafe(self): Tensor.manual_seed(0) a = Tensor.rand(3, 4, 5).realize() out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()