From 3c62315aa8512bb830d6f4a27c57359c9c7e0944 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 26 Oct 2024 13:20:32 +0300 Subject: [PATCH 1/4] add resnet pf (#7310) * add resnet pf * all platforms --- test/test_linearizer_failures.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index ece7d43787..e9b3d9a524 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1317,5 +1317,21 @@ class TestLinearizerFailures(unittest.TestCase): opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"]) + @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") + def test_failure_55(self): + ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( + UOp(UOps.STORE, dtypes.void, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(UOps.CAST, dtypes.half, arg=None, src=( + UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(UOps.CAST, dtypes.float, arg=None, src=( + UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(UOps.LOAD, dtypes.half, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) + opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["CUDA", "NV", "METAL", "AMD", "HIP"]) + if __name__ == '__main__': unittest.main() From 293714610a24bad39188907129846ee63ab00b35 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 26 Oct 2024 13:59:45 +0300 Subject: [PATCH 2/4] capture beam log runtime errors (#7311) --- test/external/fuzz_linearizer.py | 2 +- tinygrad/engine/search.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 502fc78951..3723f70bc8 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -277,7 +277,7 @@ if __name__ == "__main__": elif args.beamreplay is not None: print(f"loading BEAM replay from file '{args.beamreplay}'") with open(args.beamreplay, 'r') as file: fdata = file.readlines() - ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata], [x.split(' :: ')[1] for x in fdata] + ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata if not x.startswith("#")], [x.split(' :: ')[1] for x in fdata if not x.startswith("#")] # dedup ast_strs and opts_list dct = defaultdict(list) diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 3d5ad17770..8bd39d2d01 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -156,7 +156,10 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n") seen_libs.add(lib) try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches')) - except RuntimeError: continue # for runtime issues + except RuntimeError as e: + if len(CAPTURE_BEAM) > 0: + with open(CAPTURE_BEAM, 'a') as f: f.write("# Upper ast finished with an error:" + str(e).replace('\n',' ')+ "\n") + continue # for runtime issues timed_lins.append((acted_lins[i], min(tms))) if BEAM_DEBUG > 1: print(f"{time.perf_counter() - st:7.2f}s: {i:5d} {len(cast(List, p.uops)):5d} uops {compile_et*1e6:12.2f} us compile/{timed_lins[-1][1]*1e6:12.2f} us run {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}") # noqa: E501 elif DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501 From 0a4d01f6d480de7c0da37ae6196a742c84584193 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 26 Oct 2024 12:42:48 -0400 Subject: [PATCH 3/4] disable simplify_valid (#7312) fixed test_failure_55. will reenable it later after fixing the bug --- test/test_linearizer_failures.py | 2 +- test/unit/test_simplify_valid_idx.py | 1 + tinygrad/codegen/uopgraph.py | 6 +++--- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index e9b3d9a524..c1c636f120 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1331,7 +1331,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] - helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["CUDA", "NV", "METAL", "AMD", "HIP"]) + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 9824766b09..739e88b5b1 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -64,6 +64,7 @@ class TestValidIdxSimplification(unittest.TestCase): "0", "(((lidx0+(gidx0*4))<19)!=True)") + @unittest.expectedFailure # FIXME def test_simplify_within_valid(self): ridx0 = Range(0, 4) ridx1 = Range(1, 4) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 4b15265e0e..92c2ad851a 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -4,7 +4,7 @@ import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat -from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid +from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -483,8 +483,8 @@ reducer = PatternMatcher([ (UPat(UOps.STORE, name="root"), delete_redundant_gates), # late fixup of unfoldable image loads (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), - # simplify valid - (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), + # # simplify valid + # (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), # image load valid idx simplification (UPat(UOps.LOAD, name="load"), simplify_image_load), # buffer load valid idx simplification From d66fe7a66fb84ae7cb594ae3394e18470ff26d2d Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 26 Oct 2024 14:21:12 -0400 Subject: [PATCH 4/4] fix simplify_valid (#7313) the simplex should compare with valid bound, not its vmin --- test/unit/test_simplify_valid_idx.py | 12 +++++++++--- tinygrad/codegen/uopgraph.py | 6 +++--- tinygrad/ops.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 739e88b5b1..21867f9e6c 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -3,7 +3,7 @@ from typing import Tuple from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, UOps +from tinygrad.ops import UOp, UOps, simplify_valid def get_gated_load_uop(valid:UOp, idx:UOp): return UOp(UOps.LOAD, dtypes.float, ( @@ -64,8 +64,7 @@ class TestValidIdxSimplification(unittest.TestCase): "0", "(((lidx0+(gidx0*4))<19)!=True)") - @unittest.expectedFailure # FIXME - def test_simplify_within_valid(self): + def test_simplify_within_valid1(self): ridx0 = Range(0, 4) ridx1 = Range(1, 4) ridx2 = Range(2, 4) @@ -77,6 +76,13 @@ class TestValidIdxSimplification(unittest.TestCase): "(((ridx0+ridx1)+ridx2)+ridx3)", "((((ridx0*3)+ridx1)<8)&((((ridx2*3)+ridx3)%4)<2))") + def test_simplify_within_valid2(self): + gidx0 = Special("gidx0", 56) + ridx0 = Range(0, 3) + alu0 = gidx0+ridx0 + valid = alu0.lt(57) & alu0.ge(1) + self.assertIsNone(simplify_valid(valid)) + class TestImageSimplification(unittest.TestCase): def check(self, load, svalid, sidx0, sidx1): load = full_graph_rewrite(load.sink()).src[0] diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 92c2ad851a..4b15265e0e 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -4,7 +4,7 @@ import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, PtrDType, ImageDType from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat -from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing +from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, identity_element, uop_given_valid, parse_valid, is_increasing, simplify_valid from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -483,8 +483,8 @@ reducer = PatternMatcher([ (UPat(UOps.STORE, name="root"), delete_redundant_gates), # late fixup of unfoldable image loads (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), - # # simplify valid - # (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), + # simplify valid + (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), # image load valid idx simplification (UPat(UOps.LOAD, name="load"), simplify_image_load), # buffer load valid idx simplification diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 4967ef0e64..55a3cf5f1b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -944,7 +944,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] - if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(expr, BinaryOps.ADD)): + if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)]) # try checking the whole clause