reduce collapse as a rule (#7464)

* reduce collapse as a rule

* better [pr]

* cleaner
Author: George Hotz (committed by GitHub)
Date: 2024-11-01 15:25:44 +07:00
parent 4f6cf1f8cc
commit 2cfca230b5

@@ -232,6 +232,15 @@ def no_vectorized_wmma(wmma:UOp):
   wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
   return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
 
+def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
+  reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
+  if len(reduce_unparented) == 0: return None
+  new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
+  ret = new_acc.assign(new_acc.alu(alu.arg, ret))
+  if alu.arg is BinaryOps.ADD:
+    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
+  return ret
+
 acc_pat, rng_pat = UPat(UOps.DEFINE_ACC, name="acc"), UPat(UOps.RANGE, name="rng")
 rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
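The new rule rests on a simple arithmetic fact. Below is a standalone sketch in plain Python (illustrative names, not tinygrad UOps and not part of the diff): accumulating a value that does not depend on a reduce RANGE is the same as multiplying it by the range's extent, while a MAX over such a range is just the value itself, which is why reduce_collapse only inserts the multiply for BinaryOps.ADD.

def looped_sum(value, start, end):
  acc = 0
  for _ in range(start, end):  # "unparented" range: the body never uses the loop variable
    acc += value
  return acc

def collapsed_sum(value, start, end):
  return value * (end - start)

assert looped_sum(3, 0, 7) == collapsed_sum(3, 0, 7) == 21

def looped_max(value, start, end):
  acc = float("-inf")
  for _ in range(start, end):
    acc = max(acc, value)
  return acc

assert looped_max(3, 0, 7) == 3  # for MAX, an unparented range can simply be dropped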
@@ -282,6 +291,9 @@ sym = symbolic_flat+PatternMatcher([
   # indexing, with cast or where
   (acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
   (acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
+  # parentless reduce
+  (acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
+  (acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
   # ** self folding **
   (UPat(UOps.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
   (UPat(UOps.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
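The two new entries register the same callback for the two reduce ALUs handled here (ADD and MAX), and a callback that returns None, as reduce_collapse does when every range is parented, leaves the node untouched. A toy dispatch loop (plain Python, not tinygrad's PatternMatcher; node shapes are made up) to illustrate that contract:

def apply_rules(node, rules):
  # try each (match, rewrite) pair; a rewrite may return None to decline, like reduce_collapse does
  for matches, rewrite in rules:
    if matches(node):
      new = rewrite(node)
      if new is not None: return new
  return node

# nodes as nested tuples: ("assign", acc, ("add"|"max", acc, ret))
node = ("assign", "acc0", ("add", "acc0", "ret0"))
rules = [
  (lambda n: n[0] == "assign" and n[2][0] in ("add", "max") and n[2][1] == n[1],
   lambda n: ("collapsed", n[1], n[2][2])),  # stand-in for reduce_collapse
]
assert apply_rules(node, rules) == ("collapsed", "acc0", "ret0")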
@@ -370,17 +382,10 @@ def do_expand(root:UOp):
   return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
 
 def do_reduce(ctx:List[int], root:UOp):
-  reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].sparents)
-  ret = root.src[0]
-  if len(reduce_parented):
-    acc = UOp(UOps.DEFINE_ACC, root.dtype,
-              (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (ctx[0],))
-    ctx[0] += 1
-    ret = acc.assign(acc.alu(root.arg, ret))
-  # for MAX, we can just ignore the unparented
-  if root.arg is BinaryOps.ADD:
-    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
-  return ret
+  acc = UOp(UOps.DEFINE_ACC, root.dtype,
+            (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
+  ctx[0] += 1
+  return acc.assign(acc.alu(root.arg, root.src[0]))
 
 def do_contract(con:UOp):
   ex = con.src[0]
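do_reduce no longer partitions ranges itself: it puts every RANGE on the DEFINE_ACC and leaves the loop-invariant ones to the new sym rule. A plain-Python check (illustrative names, not tinygrad code) that the two strategies compute the same thing for an ADD reduce:

def old_style(f, parented_extent, unparented_extent):
  acc = 0
  for i in range(parented_extent): acc += f(i)  # only parented ranges reach the accumulator
  return acc * unparented_extent                # unparented extent folded in immediately (ADD case)

def new_style(f, parented_extent, unparented_extent):
  acc = 0
  for i in range(parented_extent):
    for _ in range(unparented_extent):          # every range starts on the accumulator...
      acc += f(i)
  return acc                                    # ...and the sym-level reduce_collapse strips the inner one

assert old_style(lambda i: i + 1, 3, 5) == new_style(lambda i: i + 1, 3, 5) == 30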
@@ -515,12 +520,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
   extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
-  # convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
-  sink = graph_rewrite(sink, just_reduce, ctx=[0])
   # initial symbolic + migrate indexing (remove this) + transcendental
   sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))
+  # convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
+  sink = graph_rewrite(sink, sym+just_reduce, ctx=[0])
   # expand
   sink = graph_rewrite(sink, sym+expander)
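The last hunk moves the REDUCE to DEFINE_ACC + ASSIGN conversion after the initial symbolic pass and fuses it with sym (sym+just_reduce), so an accumulator created by just_reduce can be collapsed within the same fixed-point rewrite instead of waiting for a later pass. A toy sketch of that interaction (plain Python with hypothetical rule names, not tinygrad's graph_rewrite):

def rewrite_fixpoint(x, rules):
  # keep applying rules until no rule changes the value
  changed = True
  while changed:
    changed = False
    for rule in rules:
      new = rule(x)
      if new is not None and new != x:
        x, changed = new, True
  return x

# strings stand in for UOp shapes; the two rules are stand-ins for just_reduce and the sym collapse
just_reduce_like = lambda s: "DEFINE_ACC+ASSIGN" if s == "REDUCE" else None
sym_like         = lambda s: "COLLAPSED" if s == "DEFINE_ACC+ASSIGN" else None

# fused pass: the ASSIGN produced by the reduce rule is collapsed in the same rewrite
assert rewrite_fixpoint("REDUCE", [just_reduce_like, sym_like]) == "COLLAPSED"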