From 2cfca230b54ddaa6266f4f40deb76d5e1e3a3371 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 1 Nov 2024 15:25:44 +0700 Subject: [PATCH] reduce collapse as a rule (#7464) * reduce collapse as a rule * better [pr] * cleaner --- tinygrad/codegen/uopgraph.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 1f007b0262..0376fa9059 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -232,6 +232,15 @@ def no_vectorized_wmma(wmma:UOp): wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas]) return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex)) +def reduce_collapse(acc:UOp, ret:UOp, alu:UOp): + reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents) + if len(reduce_unparented) == 0: return None + new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented)) + ret = new_acc.assign(new_acc.alu(alu.arg, ret)) + if alu.arg is BinaryOps.ADD: + for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) + return ret + acc_pat, rng_pat = UPat(UOps.DEFINE_ACC, name="acc"), UPat(UOps.RANGE, name="rng") rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat) @@ -282,6 +291,9 @@ sym = symbolic_flat+PatternMatcher([ # indexing, with cast or where (acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse), (acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse), + # parentless reduce + (acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse), + (acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse), # ** self folding ** (UPat(UOps.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST (UPat(UOps.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP @@ -370,17 +382,10 @@ def do_expand(root:UOp): return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args) def do_reduce(ctx:List[int], root:UOp): - reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].sparents) - ret = root.src[0] - if len(reduce_parented): - acc = UOp(UOps.DEFINE_ACC, root.dtype, - (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (ctx[0],)) - ctx[0] += 1 - ret = acc.assign(acc.alu(root.arg, ret)) - # for MAX, we can just ignore the unparented - if root.arg is BinaryOps.ADD: - for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) - return ret + acc = UOp(UOps.DEFINE_ACC, root.dtype, + (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],)) + ctx[0] += 1 + return acc.assign(acc.alu(root.arg, root.src[0])) def do_contract(con:UOp): ex = con.src[0] @@ -515,12 +520,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else () extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([]) + # convert REDUCE to DEFINE_ACC + ASSIGN (contextual) + sink = graph_rewrite(sink, just_reduce, ctx=[0]) + # initial symbolic + migrate indexing (remove this) + transcendental sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)) - # convert REDUCE to DEFINE_ACC + ASSIGN (contextual) - sink = graph_rewrite(sink, sym+just_reduce, ctx=[0]) - # expand sink = graph_rewrite(sink, sym+expander)