mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
merge_double_reduce without asserts [pr] (#9650)
This commit is contained in:
@@ -300,11 +300,6 @@ def elementwise_view_right(root:UOp):
|
||||
# reshape to match downstream shapes
|
||||
return root.replace(src=tuple(new_src)).reshape(root.shape)
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp):
|
||||
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
||||
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
|
||||
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
|
||||
|
||||
# push VIEW to children
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# push a non contiguous ShapeTracker through reduceop
|
||||
@@ -313,8 +308,9 @@ view_right = merge_views+PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),), name="v"),), name="r"), reduceop_view_right),
|
||||
# apply view after elementwise ops
|
||||
(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="root"), elementwise_view_right),
|
||||
# double reduce op collapses to a single reduce op
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
|
||||
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] == r2.arg[0] else None),
|
||||
])
|
||||
|
||||
# **** unbind variables
|
||||
|
||||
Reference in New Issue
Block a user