# Patch summary (text before the "diff --git" header is not part of the patch payload):
#   * Removes the helper merge_double_reduce(root, first_reduce) from schedule.py.
#   * In the view_right pattern list, replaces the REDUCE_AXIS-of-REDUCE_AXIS rule's
#     callback with an inline lambda(r1, r2), where r1 is the inner reduce and r2 the outer.
#   * Old helper: asserted both reduces share arg[0] and that first_reduce.src[0].toposort
#     contains no further REDUCE_AXIS, then returned first_reduce with the axes
#     root.axis_arg+first_reduce.axis_arg.
#   * New lambda: returns r1 with axes r2.arg[1]+r1.arg[1] when r1.arg[0] == r2.arg[0],
#     otherwise returns None; the "no more than two reduceops" check is dropped.
#   * NOTE(review): presumably axis_arg corresponds to arg[1] and a None return means
#     "no rewrite" to PatternMatcher -- neither is shown in this patch; confirm.
#   * NOTE(review): line breaks and leading indentation below were reconstructed from a
#     whitespace-collapsed copy (2-space indent assumed; line counts match the hunk
#     headers). Verify whitespace against the real file before applying.
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index e2d45d7446..666a5d940b 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -300,11 +300,6 @@ def elementwise_view_right(root:UOp):
   # reshape to match downstream shapes
   return root.replace(src=tuple(new_src)).reshape(root.shape)
 
-def merge_double_reduce(root:UOp, first_reduce:UOp):
-  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
-  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
-  return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
-
 # push VIEW to children
 view_right = merge_views+PatternMatcher([
   # push a non contiguous ShapeTracker through reduceop
@@ -313,8 +308,9 @@ view_right = merge_views+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right),
   # apply view after elementwise ops
   (UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right),
-  # double reduce op collapses to a single reduce op
-  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
+  # merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
+   lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None),
 ])
 
 # **** unbind variables