Fold expand preceding reduce if the reduction is on the same axis as the expansion (#1134)

* fold expands that precede a reduce if the reduction is on the same axis as the expansion

* add deterministic test for SIMPLIFY_SUM_RESHAPE_EXPAND_SUM optimization

* add a test case to make sure we don't fold reduce-expand-reduce on different axes
This commit is contained in:
Rayan Hatout
2023-07-06 21:41:05 +01:00
committed by GitHub
parent f109af3cbb
commit 9975f24452
2 changed files with 67 additions and 5 deletions

View File

@@ -21,15 +21,43 @@ LAZYCACHE = getenv("LAZYCACHE", 1)
# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS, SIMPLIFY_SUM_RESHAPE_EXPAND_SUM = OPT>=2, OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
def _simplify_sum_reshape_expand_sum(self:LazyBuffer, src: Any, prev_src: Any) -> Optional[LazyOp]:
if prev_src.op.op == MovementOps.EXPAND:
if src.op.op == ReduceOps.SUM:
if src.shape == self.shape:
dim_difference = [i for i, (a, b) in enumerate(zip(prev_src.shape, self.shape)) if a != b]
# NOTE: we can probably also handle the case where more than one dimension is different with more thought
if len(dim_difference) == 1:
expansion_index = dim_difference[0]
expansion_size = prev_src.shape[expansion_index]
return LazyOp(BinaryOps.MUL, (src, LazyBuffer.const_like(src, expansion_size)))
return None
# **** realize functions ****
def _ast_reduceops(self:LazyBuffer) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
# NOTE: mypy doesn't know that if not src.realized, then src.op must be a LazyOp so we have to ignore a bunch of warnings
src = self.op.src[0]
if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype is BinaryOps and len(src.children) <= 1:
src = src.op # type: ignore
if not src.realized:
# When a tensor is reduced, reshaped/expanded back and then reduced again along the same axis,
# it's equivalent to performing the initial reduction and multiplying the result
# by the size of the expanded dimension.
if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND: # type: ignore
expanded = src.op.src[0] # type: ignore
if expanded.op.op == MovementOps.RESHAPE: # type: ignore
reshaped = expanded.op.src[0] # type: ignore
simplified = _simplify_sum_reshape_expand_sum(self, reshaped, src)
else:
simplified = _simplify_sum_reshape_expand_sum(self, expanded, src)
if simplified: return simplified
if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1:
# If we did remove an expand above, we might stumble back into a case where the reduction is not necessary
if src.shape == self.shape:
return src.op # type: ignore
src = src.op # type: ignore
return LazyOp(self.op.op, (src,), self.op.arg)
# this supports late merging an upstream Reduce op and even an Elementwise op above that
@@ -110,7 +138,9 @@ class LazyBuffer:
if not self.realized:
# get real ops first
if self.optype is BinaryOps: self.op = _ast_binaryops(self)
elif self.optype is ReduceOps: self.op = _ast_reduceops(self)
elif self.optype is ReduceOps:
self.op = _ast_reduceops(self)
if self.op.op in BinaryOps: self.op = _ast_binaryops(self)
elif self.optype is LoadOps: LOAD_OPS_DISPATCHER[cast(LoadOps, self.op.op)](self)
# run the ast if we still have to, and log the op
if not self.realized: