From fac8633ba84b1bb90fc6c0e1f9115c4e57ee1257 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 14 Nov 2023 11:52:38 -0500
Subject: [PATCH] explicit opts for test_linearizer_failures (#2299)

* explicit opts for test_linearizer_failures

* typo

* update the invalid check
---
 test/external/fuzz_linearizer.py | 68 +++++++++++++++++---------------
 test/test_linearizer_failures.py | 35 +++++++++-------
 2 files changed, 58 insertions(+), 45 deletions(-)

diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index 7cf22b1e0f..d65ba01ba2 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -1,11 +1,11 @@
 import random, traceback
 import numpy as np
-from collections import Counter, defaultdict
+from collections import Counter
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops
 from tinygrad.graph import print_tree
-from tinygrad.helpers import ImageDType, prod, getenv
+from tinygrad.helpers import getenv
 from tinygrad.ops import Device, Compiled, Interpreted
 from tinygrad.lazy import vars_from_ast
 
@@ -19,6 +19,37 @@ class LB:
     self.dtype = dtype
 
 
+def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
+  if rawbufs is None: rawbufs = bufs_from_lin(lin)
+  if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)}
+
+  # TODO: images needs required_optimization
+  if isinstance(device, Compiled):
+    try:
+      prg = device.to_program(lin)
+    except:
+      print(lin.ast)
+      traceback.print_exc()
+      print("COMPILE FAILED!!")
+      return "COMPILE_ERROR"
+    try:
+      prg.exec(rawbufs, var_vals, force_wait=True)
+    except:
+      print(lin.ast)
+      traceback.print_exc()
+      print("EXEC FAILED!!")
+      return "EXEC_ERROR"
+  else:
+    try:
+      device.exec_ast(lin.ast, output=LB(rawbufs[0], rawbufs[0].dtype), inputs=[LB(buf, buf.dtype) for buf in rawbufs[1:]])
+    except Exception as e:
+      print(lin.ast)
+      traceback.print_exc()
+      return str(type(e))
+
+  return "PASS"
+
+
 def fuzz_linearizer(lin: Linearizer):
   random.seed(42)
   np.random.seed(42)
@@ -44,33 +75,11 @@ def fuzz_linearizer(lin: Linearizer):
     # get a new output buffer
     rawbufs[0] = type(rawbufs[0])(rawbufs[0].size, rawbufs[0].dtype)
     var_vals = {v: random.randint(v.min, v.max) for v in vars_from_ast(lin.ast)}
-
-    # TODO: images needs required_optimization
-    if isinstance(device, Compiled):
-      try:
-        prg = device.to_program(lin)
-      except:
-        print(lin.ast)
-        traceback.print_exc()
-        print("COMPILE FAILED!!")
-        return "COMPILE_ERROR"
-      try:
-        prg.exec(rawbufs, var_vals, force_wait=True)
-      except:
-        print(lin.ast)
-        traceback.print_exc()
-        print("EXEC FAILED!!")
-        return "EXEC_ERROR"
-    else:
-      try:
-        device.exec_ast(lin.ast, output=LB(rawbufs[0], rawbufs[0].dtype), inputs=[LB(buf, buf.dtype) for buf in rawbufs[1:]])
-      except Exception as e:
-        print(lin.ast)
-        traceback.print_exc()
-        return e
+    if (msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
+      print(f"{lin.applied_opts=}")
+      return msg
 
     result = rawbufs[0].toCPU()
-
     if ground_truth is None:
       ground_truth = result
     else:
@@ -79,11 +88,8 @@
       except AssertionError:
         print(lin.ast)
         traceback.print_exc()
+        print(f"{lin.applied_opts=}")
         return "NOT_ALLCLOSE"
-      except Exception as e:
-        print(lin.ast)
-        traceback.print_exc()
-        return e
 
   return "PASS"
 
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index 8d4004fc23..4e3ff0d2e5 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -1,7 +1,8 @@
 import unittest
 from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.ops import Device, Interpreted
-from test.external.fuzz_linearizer import fuzz_linearizer
+from tinygrad.features.search import Opt, OptOps
+from tinygrad.ops import Device
+from test.external.fuzz_linearizer import run_linearizer
 
 # stuff needed to unpack a kernel
 from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
@@ -11,11 +12,17 @@ from tinygrad.shape.view import View
 from tinygrad.shape.symbolic import Variable
 inf, nan = float('inf'), float('nan')
 
-def helper_test_lin(lin, fixed_platforms):
+def helper_test_lin(lin: Linearizer, opts, fixed_platforms):
+  for opt in opts:
+    try:
+      lin.apply_opt(opt)
+    except AssertionError:
+      # it's considered fixed if we invalidated the opts
+      return Device.DEFAULT in fixed_platforms
   if Device.DEFAULT in fixed_platforms:
-    return fuzz_linearizer(lin) == "PASS"
+    return run_linearizer(lin) == "PASS"
   else:
-    return fuzz_linearizer(lin) != "PASS"
+    return run_linearizer(lin) != "PASS"
 
 class TestLinearizerFailures(unittest.TestCase):
   @unittest.skip("this is currently failing")
@@ -26,33 +33,33 @@
 
   # NOTE: test cases from fuzzer run. if you fixed something and it no longer fails, add platform to fixed_platforms list in helper_test_lin().
-  @unittest.skipUnless(isinstance(Device[Device.DEFAULT], Interpreted), "fails on Interpreted")
+  @unittest.skipUnless(Device.DEFAULT in ["CPU", "TORCH"], "fails on these backends")
   def test_failure_2(self):
     ast = LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))))),), arg=(32, 2, 37, 9, 1, 1))
-    lin = Linearizer(ast)
-    assert helper_test_lin(lin, fixed_platforms=[])
+    opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)]
+    assert helper_test_lin(Linearizer(ast), opts, fixed_platforms=[])
 
   @unittest.skipUnless(Device.DEFAULT in ["METAL", "GPU", "LLVM"], "fails on these backends")
   def test_failure_3(self):
     ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)))),), arg=(32, 8, 16, 1))
-    lin = Linearizer(ast)
+    opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=32)]
     # METAL: AssertionError: Error Domain=AGXMetalG13X Code=3 "Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)" UserInfo={NSLocalizedDescription=Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)}
-    assert helper_test_lin(lin, fixed_platforms=["LLVM"])
+    assert helper_test_lin(Linearizer(ast), opts, fixed_platforms=["LLVM"])
 
   @unittest.skipUnless(Device.DEFAULT in ["METAL", "LLVM"], "fails on these backends")
   def test_failure_4(self):
     ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 4, 1, 12, 2, 29), strides=(0, 0, 0, 2, 0, 216, 1, 8), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 1), (0, 11), (0, 2), (0, 27)), contiguous=False), View(shape=(1, 1, 1, 4, 22, 84), strides=(0, 0, 0, 696, 58, 1), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 12), (0, 58)), contiguous=False), View(shape=(1, 1, 1, 4, 2, 11, 3, 28), strides=(0, 0, 0, 1848, 924, 84, 28, 1), offset=0, mask=None, contiguous=True))))),), arg=(1, 1, 1, 4, 1, 11, 1, 28))
-    lin = Linearizer(ast)
+    opts = [Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)]
     # related to OptOps.NOLOCALS
     # IndexError: list index out of range
-    assert helper_test_lin(lin, fixed_platforms=["LLVM"])
+    assert helper_test_lin(Linearizer(ast), opts, fixed_platforms=["LLVM"])
 
   @unittest.skipUnless(Device.DEFAULT in ["CLANG", "LLVM"], "fails on these backends")
   def test_failure_5(self):
     ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.MEM, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),), arg=(1, 1, 1, 1, 1, 1, 1, 1))
+    opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
     # EXEC_ERROR, it has no global_size
-    lin = Linearizer(ast)
-    assert helper_test_lin(lin, fixed_platforms=["CLANG", "LLVM"])
+    assert helper_test_lin(Linearizer(ast), opts, fixed_platforms=["CLANG", "LLVM"])
 
 if __name__ == '__main__':
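
With the opts now recorded explicitly in each test, a fuzzer-found failure can be replayed deterministically instead of re-running the whole fuzz loop. The sketch below shows how the pieces in this patch compose; `reproduce` is a hypothetical helper named here for illustration, while `Linearizer`, `apply_opt`, the `Opt(op=..., axis=..., amt=...)` records, and the string results of `run_linearizer` all come from the code above.

# Hypothetical repro sketch (assumes a tinygrad checkout at this commit).
from tinygrad.codegen.linearizer import Linearizer
from test.external.fuzz_linearizer import run_linearizer

def reproduce(ast, opts):
  lin = Linearizer(ast)    # build a fresh kernel from the recorded AST
  for opt in opts:
    lin.apply_opt(opt)     # raises AssertionError if the opt is invalid for this kernel
  # run_linearizer allocates buffers and var_vals itself when not passed; it
  # returns "PASS", "COMPILE_ERROR", "EXEC_ERROR", or str(type(e)) on Interpreted backends
  return run_linearizer(lin)

helper_test_lin then flips the expectation depending on whether Device.DEFAULT is in fixed_platforms, so per the NOTE above, a platform that fixes a case must be added to that list or the test starts failing.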