From 55e02cdd84500afd5a0ebd510f4f883e420693f1 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 20 Jun 2024 16:10:08 +0300 Subject: [PATCH] generic gate folding (#5061) * add assert * fold truthy gates [run_process_replay] * fold falsy gates [run_process_replay] [no_assert] * redo asserts * check both barriers * spec start * spec end * assert srcs * make test_fold_gated_load_local better * [run_process_replay] [no_assert] --- test/test_uop_graph.py | 50 ++++++++++++++++++++++++++++++++++++++++ tinygrad/codegen/uops.py | 11 ++++++--- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 45e8ec1040..f6ff5da001 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -5,6 +5,15 @@ from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps from tinygrad.codegen.uops import UOpGraph, UOps, UOp class TestUOpGraph(unittest.TestCase): + # TODO: move to test.helpers + def assert_equiv_uops(self, uop1:UOp, uop2:UOp): + # NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops + self.assertEqual(uop1.op, uop2.op) + self.assertEqual(uop1.dtype, uop2.dtype) + self.assertEqual(uop1.arg, uop2.arg) + self.assertEqual(len(uop1.src), len(uop2.src)) + for s1, s2 in zip(uop1.src, uop2.src): self.assert_equiv_uops(s1, s2) + def test_add_constant_fold(self): c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) @@ -72,5 +81,46 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(out.src[1].op, UOps.CONST) self.assertEqual(out.src[1].arg, 6) + def test_fold_gated_load(self): + glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True)) + glbl1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (1, False)) + glbl2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (2, False)) + idx = UOp.const(dtypes.int, 0) + ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2))) + ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3))) + uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld0+ld1))]) + ld0, ld1 = uops[-1].src[2].src + # ld0 becomes the invalid value + self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2)) + # the gate and invalid value are deleted from ld1 + self.assert_equiv_uops(ld1, UOp.load(glbl2, idx, dtype=dtypes.int)) + + def test_fold_gated_load_local(self): + glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True)) + smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1)) + lidx = UOp(UOps.SPECIAL, dtypes.int, (), (0, "lidx1", 16)) + st = UOp(UOps.STORE, None, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int))) + barrier = UOp(UOps.BARRIER, None, (st, )) + ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier)) + ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3), barrier)) + uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld0+ld1))]) + ld0, ld1 = uops[-1].src[2].src + # ld0 becomes the invalid value + self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2)) + # the gate and invalid value are deleted from ld1 + self.assert_equiv_uops(ld1, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int)) + + def test_fold_gated_store(self): + glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True)) + idx0 = UOp.const(dtypes.int, 0) + idx1 = UOp.const(dtypes.int, 0) + val = UOp.const(dtypes.int, 42) + st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False))) + st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True))) + uops = UOpGraph([st0, st1]) + # only the second store happens + self.assertEqual(len(uops.uops), 4) + self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val)) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index f388c0c707..3a19114683 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -256,10 +256,13 @@ constant_folder = PatternMatcher([ # cast NOOP (NOTE: it's str to deal with PtrDType) (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None), # fold gated LOAD/STORE - (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.int, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)), - (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(dtypes.int, 1), UOp.cvar("var"), UOp.var("barrier")), + (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var")), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)), + (UOp.load(UOp.var("buf"), UOp.var("idx"), UOp.const(None, 1), UOp.cvar("var"), UOp.var("barrier")), lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)), - (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(dtypes.int, 1)), UOp.store), + (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var")), lambda var: var), + (UOp.load(UOp.var(), UOp.var(), UOp.const(None, 0), UOp.cvar("var"), UOp.var()), lambda var: var), + (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.var("val"), UOp.const(None, 1)), UOp.store), + (UOp.store(UOp.var(), UOp.var(), UOp.var(), UOp.const(None, 0)), lambda: UOp(UOps.NOOP)), ]) # *** uop graph *** @@ -465,6 +468,8 @@ class UOpGraph: if uop is UOps.DEFINE_ACC: arg = arg[0] assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}" if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg + if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype is dtypes.bool + if uop is UOps.STORE and len(src) == 4: assert src[3].dtype is dtypes.bool if uop is UOps.ALU: if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"