merge_double_reduce without asserts [pr] (#9650)

This commit is contained in:
qazal
2025-03-31 19:17:05 +08:00
committed by GitHub
parent 1444069c09
commit 5171b098e5

View File

@@ -300,11 +300,6 @@ def elementwise_view_right(root:UOp):
# reshape to match downstream shapes
return root.replace(src=tuple(new_src)).reshape(root.shape)
def merge_double_reduce(root:UOp, first_reduce:UOp):
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
# push VIEW to children
view_right = merge_views+PatternMatcher([
# push a non contiguous ShapeTracker through reduceop
@@ -313,8 +308,9 @@ view_right = merge_views+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right),
# double reduce op collapses to a single reduce op
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None),
])
# **** unbind variables