From 9975f244526dc636ccdfb94bff7840e65a319e8b Mon Sep 17 00:00:00 2001
From: Rayan Hatout
Date: Thu, 6 Jul 2023 21:41:05 +0100
Subject: [PATCH] Fold expand preceding reduce if the reduction is on the same
 axis as the expansion (#1134)

* fold expands that precede a reduce if the reduction is on the same axis as the expansion

* add deterministic test for SIMPLIFY_SUM_RESHAPE_EXPAND_SUM optimization

* add a test case to make sure we don't fold reduce-expand-reduce on different axes
---
 test/external/external_test_opt.py | 24 +++++++++++++++++++++++-
 tinygrad/lazy.py                   | 38 ++++++++++++++++++++++++++++++++++----
 2 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index 89a682add9..88dbe2c1a5 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python
 import os
+
+import torch
 if "OPT" not in os.environ:
   os.environ["OPT"] = "2"

@@ -112,7 +114,6 @@ class TestOptBinOp(unittest.TestCase):
   #def test_no_binop_rerun_reduce(self): return self._test_no_binop_rerun(lambda a,b: (a*b).sum(), lambda a,b: (a*b).reshape(16, 16, 1).sum())
   #def test_no_binop_rerun_reduce_alt(self): return self._test_no_binop_rerun(lambda a,b: a.sum(1)+b[0], lambda a,b: a.sum(1).reshape(1,16)+b[0])

-@unittest.skip("elementwise with >1 reduce inputs currently don't fuse")
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptReduceLoop(unittest.TestCase):
   def test_loop_left(self):
@@ -359,5 +360,26 @@ class TestOpt(unittest.TestCase):
       cache_len = len(GlobalCounters.cache)
     assert cache_len == 1, "contiguous wasn't folded"

+  def _test_fold_expand_reduce_helper(self, n, m, axis, allowed):
+    # torch computes the eager reference; tinygrad runs under the kernel-count check
+    b = torch.ones(n, m).sum(axis).reshape(n, 1).expand(n, m).sum(axis)
+    with CLCache(allowed=allowed):
+      a = Tensor.ones(n, m).sum(axis).reshape(n, 1).expand(n, m).sum(axis)
+      a.realize()
+    np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
+
+  def test_expand_reduce_is_folded_on_same_axis(self):
+    for axis in [0, 1]:
+      for n in [4, 8, 16]:
+        self._test_fold_expand_reduce_helper(n, n, axis, allowed=2)
+
+  def test_expand_reduce_is_not_folded_on_different_axes(self):
+    axis1, axis2 = 0, 1
+    for n in [4, 8, 16]:
+      b = torch.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+      with CLCache(allowed=3):
+        a = Tensor.ones(n, n).sum(axis1).reshape(n, 1).expand(n, n).sum(axis2)
+        a.realize()
+      np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 40db7ba65e..7806bf44e9 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -21,15 +21,43 @@ LAZYCACHE = getenv("LAZYCACHE", 1)

 # TODO: movement ops that only change shape are really nops. treat them as such
 REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
-MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops
+MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS, SIMPLIFY_SUM_RESHAPE_EXPAND_SUM = OPT>=2, OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops
 PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3

+def _simplify_sum_reshape_expand_sum(self:LazyBuffer, src: Any, prev_src: Any) -> Optional[LazyOp]:
+  if prev_src.op.op == MovementOps.EXPAND:
+    if src.op.op == ReduceOps.SUM:
+      if src.shape == self.shape:
+        dim_difference = [i for i, (a, b) in enumerate(zip(prev_src.shape, self.shape)) if a != b]
+        # NOTE: we can probably also handle the case where more than one dimension is different with more thought
+        if len(dim_difference) == 1:
+          expansion_index = dim_difference[0]
+          expansion_size = prev_src.shape[expansion_index]
+          return LazyOp(BinaryOps.MUL, (src, LazyBuffer.const_like(src, expansion_size)))
+  return None
+
 # **** realize functions ****
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
   # TODO: this can also corealize a binary op after the reduce, not just before
+  # NOTE: mypy doesn't know that if not src.realized, then src.op must be a LazyOp so we have to ignore a bunch of warnings
   src = self.op.src[0]
-  if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype is BinaryOps and len(src.children) <= 1:
-    src = src.op # type: ignore
+  if not src.realized:
+    # When a tensor is reduced, reshaped/expanded back and then reduced again along the same axis,
+    # it's equivalent to performing the initial reduction and multiplying the result
+    # by the size of the expanded dimension.
+    if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND: # type: ignore
+      expanded = src.op.src[0] # type: ignore
+      if expanded.op.op == MovementOps.RESHAPE: # type: ignore
+        reshaped = expanded.op.src[0] # type: ignore
+        simplified = _simplify_sum_reshape_expand_sum(self, reshaped, src)
+      else:
+        simplified = _simplify_sum_reshape_expand_sum(self, expanded, src)
+      if simplified: return simplified
+    if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1:
+      # If we did remove an expand above, we might stumble back into a case where the reduction is not necessary
+      if src.shape == self.shape:
+        return src.op # type: ignore
+      src = src.op # type: ignore
   return LazyOp(self.op.op, (src,), self.op.arg)

 # this supports late merging an upstream Reduce op and even an Elementwise op above that
@@ -110,7 +138,9 @@ class LazyBuffer:
     if not self.realized:
       # get real ops first
       if self.optype is BinaryOps: self.op = _ast_binaryops(self)
-      elif self.optype is ReduceOps: self.op = _ast_reduceops(self)
+      elif self.optype is ReduceOps:
+        self.op = _ast_reduceops(self)
+        if self.op.op in BinaryOps: self.op = _ast_binaryops(self)
       elif self.optype is LoadOps: LOAD_OPS_DISPATCHER[cast(LoadOps, self.op.op)](self)
       # run the ast if we still have to, and log the op
       if not self.realized:
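
For reference (not part of the patch): a minimal standalone sketch of the identity the fold relies on, written in plain NumPy rather than tinygrad's lazy graph; the names x, n, m, and axis are illustrative only. Reducing with SUM while keeping the reduced axis, expanding back along that same axis, and summing again equals the first reduction scaled by the size of the expanded dimension, which is what the rewrite to LazyOp(BinaryOps.MUL, (src, LazyBuffer.const_like(src, expansion_size))) computes.

# sketch only: sum -> expand back along the reduced axis -> sum == sum * expansion size
import numpy as np

n, m, axis = 4, 8, 1
x = np.random.rand(n, m).astype(np.float32)    # illustrative input

s = x.sum(axis=axis, keepdims=True)            # first reduction, shape (4, 1)
expanded = np.broadcast_to(s, (n, m))          # expand back along the reduced axis
lhs = expanded.sum(axis=axis)                  # reduce again on the same axis
rhs = x.sum(axis=axis) * m                     # first reduction scaled by the expansion size
np.testing.assert_allclose(lhs, rhs, rtol=1e-5)

# Reducing the expanded tensor over the other axis instead gives s.sum() at every
# position, not a per-element scale of the first reduction, so that pattern is left
# unfolded (see test_expand_reduce_is_not_folded_on_different_axes).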