mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-18 18:35:12 -05:00)
Fold expand preceding reduce if the reduction is on the same axis as the expansion (#1134)
* fold expands that precede a reduce if the reduction is on the same axis as the expansion
* add deterministic test for SIMPLIFY_SUM_RESHAPE_EXPAND_SUM optimization
* add a test case to make sure we don't fold reduce-expand-reduce on different axes
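The identity behind the fold: summing along an axis, expanding the result back along that same axis, and summing again is the same as scaling the first sum by the expansion size, since the second sum just adds the same value repeatedly. A minimal NumPy illustration (not part of this commit's code):

```python
# sum -> expand back along the same axis -> sum again == first sum * expansion size
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
s = x.sum(axis=1, keepdims=True)                                   # shape (3, 1)
roundtrip = np.broadcast_to(s, (3, 4)).sum(axis=1, keepdims=True)  # expand + re-sum
np.testing.assert_allclose(roundtrip, s * 4)                       # folded form
```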
@@ -21,15 +21,43 @@ LAZYCACHE = getenv("LAZYCACHE", 1)
 # TODO: movement ops that only change shape are really nops. treat them as such
 REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
-MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops
+MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS, SIMPLIFY_SUM_RESHAPE_EXPAND_SUM = OPT>=2, OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops
 PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
 
+def _simplify_sum_reshape_expand_sum(self:LazyBuffer, src: Any, prev_src: Any) -> Optional[LazyOp]:
+  if prev_src.op.op == MovementOps.EXPAND:
+    if src.op.op == ReduceOps.SUM:
+      if src.shape == self.shape:
+        dim_difference = [i for i, (a, b) in enumerate(zip(prev_src.shape, self.shape)) if a != b]
+        # NOTE: we can probably also handle the case where more than one dimension is different with more thought
+        if len(dim_difference) == 1:
+          expansion_index = dim_difference[0]
+          expansion_size = prev_src.shape[expansion_index]
+          return LazyOp(BinaryOps.MUL, (src, LazyBuffer.const_like(src, expansion_size)))
+  return None
+
 # **** realize functions ****
 
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
   # TODO: this can also corealize a binary op after the reduce, not just before
+  # NOTE: mypy doesn't know that if not src.realized, then src.op must be a LazyOp so we have to ignore a bunch of warnings
   src = self.op.src[0]
-  if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype is BinaryOps and len(src.children) <= 1:
-    src = src.op # type: ignore
+  if not src.realized:
+    # When a tensor is reduced, reshaped/expanded back and then reduced again along the same axis,
+    # it's equivalent to performing the initial reduction and multiplying the result
+    # by the size of the expanded dimension.
+    if SIMPLIFY_SUM_RESHAPE_EXPAND_SUM and src.op.op == MovementOps.EXPAND: # type: ignore
+      expanded = src.op.src[0] # type: ignore
+      if expanded.op.op == MovementOps.RESHAPE: # type: ignore
+        reshaped = expanded.op.src[0] # type: ignore
+        simplified = _simplify_sum_reshape_expand_sum(self, reshaped, src)
+      else:
+        simplified = _simplify_sum_reshape_expand_sum(self, expanded, src)
+      if simplified: return simplified
+    if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1:
+      # If we did remove an expand above, we might stumble back into a case where the reduction is not necessary
+      if src.shape == self.shape:
+        return src.op # type: ignore
+      src = src.op # type: ignore
   return LazyOp(self.op.op, (src,), self.op.arg)
 
 # this supports late merging an upstream Reduce op and even an Elementwise op above that
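To make the axis check in `_simplify_sum_reshape_expand_sum` concrete: the fold only applies when exactly one dimension of the expanded (pre-reduce) shape differs from the output shape, and the multiplier is that dimension's size. A self-contained sketch of the check on plain shape tuples (`fold_factor` is an illustrative name, not tinygrad code):

```python
from typing import Optional, Tuple

def fold_factor(expanded_shape: Tuple[int, ...], reduced_shape: Tuple[int, ...]) -> Optional[int]:
  # the expand and the second sum cancel only if exactly one dimension differs
  diff = [i for i, (a, b) in enumerate(zip(expanded_shape, reduced_shape)) if a != b]
  if len(diff) == 1:
    return expanded_shape[diff[0]]  # scale the first sum's result by this factor
  return None

# sum over axis 1 of a (3, 4) tensor with keepdim gives (3, 1); expanding back to
# (3, 4) and summing axis 1 again just scales the first sum by 4
assert fold_factor((3, 4), (3, 1)) == 4
# reduce-expand-reduce on different axes: two dims differ, so no fold
assert fold_factor((1, 4), (3, 1)) is None
```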
@@ -110,7 +138,9 @@ class LazyBuffer:
     if not self.realized:
       # get real ops first
       if self.optype is BinaryOps: self.op = _ast_binaryops(self)
-      elif self.optype is ReduceOps: self.op = _ast_reduceops(self)
+      elif self.optype is ReduceOps:
+        self.op = _ast_reduceops(self)
+        if self.op.op in BinaryOps: self.op = _ast_binaryops(self)
       elif self.optype is LoadOps: LOAD_OPS_DISPATCHER[cast(LoadOps, self.op.op)](self)
       # run the ast if we still have to, and log the op
       if not self.realized:
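This second hunk is needed because the fold can rewrite a `ReduceOps.SUM` node into a `BinaryOps.MUL`, so the result of `_ast_reduceops` may now have to pass through `_ast_binaryops` before realization. A hedged, end-to-end sketch of a deterministic check in the spirit of the test this PR describes (the actual test may differ; the `Tensor` calls are assumed from tinygrad's API of this era):

```python
import numpy as np
from tinygrad.tensor import Tensor

a = Tensor(np.arange(12, dtype=np.float32).reshape(3, 4))
# sum over axis 1, expand back to (3, 4), then sum over axis 1 again...
folded = a.sum(axis=1, keepdim=True).expand(3, 4).sum(axis=1, keepdim=True)
# ...which the fold rewrites to the first sum times the expansion size (4);
# the identity holds whether or not the optimization fires, so this is deterministic
expected = a.sum(axis=1, keepdim=True) * 4
np.testing.assert_allclose(folded.numpy(), expected.numpy())
```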