diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 502fc78951..3723f70bc8 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -277,7 +277,7 @@ if __name__ == "__main__": elif args.beamreplay is not None: print(f"loading BEAM replay from file '{args.beamreplay}'") with open(args.beamreplay, 'r') as file: fdata = file.readlines() - ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata], [x.split(' :: ')[1] for x in fdata] + ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata if not x.startswith("#")], [x.split(' :: ')[1] for x in fdata if not x.startswith("#")] # dedup ast_strs and opts_list dct = defaultdict(list) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index ece7d43787..c1c636f120 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1317,5 +1317,21 @@ class TestLinearizerFailures(unittest.TestCase): opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"]) + @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") + def test_failure_55(self): + ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( + UOp(UOps.STORE, dtypes.void, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(UOps.CAST, dtypes.half, arg=None, src=( + UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(UOps.CAST, dtypes.float, arg=None, src=( + UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(UOps.LOAD, dtypes.half, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) + opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) + if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 9824766b09..21867f9e6c 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -3,7 +3,7 @@ from typing import Tuple from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, UOps +from tinygrad.ops import UOp, UOps, simplify_valid def get_gated_load_uop(valid:UOp, idx:UOp): return UOp(UOps.LOAD, dtypes.float, ( @@ -64,7 +64,7 @@ class TestValidIdxSimplification(unittest.TestCase): "0", "(((lidx0+(gidx0*4))<19)!=True)") - def test_simplify_within_valid(self): + def test_simplify_within_valid1(self): ridx0 = Range(0, 4) ridx1 = Range(1, 4) ridx2 = Range(2, 4) @@ -76,6 +76,13 @@ class TestValidIdxSimplification(unittest.TestCase): "(((ridx0+ridx1)+ridx2)+ridx3)", "((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3)%4)<2))") + def test_simplify_within_valid2(self): + gidx0 = Special("gidx0", 56) + ridx0 = Range(0, 3) + alu0 = gidx0+ridx0 + valid = alu0.lt(57) & alu0.ge(1) + self.assertIsNone(simplify_valid(valid)) + class TestImageSimplification(unittest.TestCase): def check(self, load, svalid, sidx0, sidx1): load = full_graph_rewrite(load.sink()).src[0] diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 3d5ad17770..8bd39d2d01 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -156,7 +156,10 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n") seen_libs.add(lib) try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches')) - except RuntimeError: continue # for runtime issues + except RuntimeError as e: + if len(CAPTURE_BEAM) > 0: + with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n") + continue # for runtime issues timed_lins.append((acted_lins[i], min(tms))) if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(List, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501 elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501 diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 4967ef0e64..55a3cf5f1b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -944,7 +944,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] - if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(expr, BinaryOps.ADD)): + if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)]) # try checking the whole clause