mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
reduce collapse as a rule (#7464)
* reduce collapse as a rule * better [pr] * cleaner
This commit is contained in:
@@ -232,6 +232,15 @@ def no_vectorized_wmma(wmma:UOp):
|
||||
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
|
||||
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
|
||||
|
||||
def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
|
||||
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
|
||||
if len(reduce_unparented) == 0: return None
|
||||
new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
|
||||
ret = new_acc.assign(new_acc.alu(alu.arg, ret))
|
||||
if alu.arg is BinaryOps.ADD:
|
||||
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
|
||||
acc_pat, rng_pat = UPat(UOps.DEFINE_ACC, name="acc"), UPat(UOps.RANGE, name="rng")
|
||||
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
|
||||
|
||||
@@ -282,6 +291,9 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# indexing, with cast or where
|
||||
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
|
||||
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
|
||||
# parentless reduce
|
||||
(acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
|
||||
(acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
|
||||
# ** self folding **
|
||||
(UPat(UOps.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
|
||||
(UPat(UOps.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
|
||||
@@ -370,17 +382,10 @@ def do_expand(root:UOp):
|
||||
return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
|
||||
|
||||
def do_reduce(ctx:List[int], root:UOp):
|
||||
reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].sparents)
|
||||
ret = root.src[0]
|
||||
if len(reduce_parented):
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype,
|
||||
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (ctx[0],))
|
||||
ctx[0] += 1
|
||||
ret = acc.assign(acc.alu(root.arg, ret))
|
||||
# for MAX, we can just ignore the unparented
|
||||
if root.arg is BinaryOps.ADD:
|
||||
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
acc = UOp(UOps.DEFINE_ACC, root.dtype,
|
||||
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
|
||||
ctx[0] += 1
|
||||
return acc.assign(acc.alu(root.arg, root.src[0]))
|
||||
|
||||
def do_contract(con:UOp):
|
||||
ex = con.src[0]
|
||||
@@ -515,12 +520,12 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
|
||||
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])
|
||||
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
|
||||
sink = graph_rewrite(sink, just_reduce, ctx=[0])
|
||||
|
||||
# initial symbolic + migrate indexing (remove this) + transcendental
|
||||
sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2))
|
||||
|
||||
# convert REDUCE to DEFINE_ACC + ASSIGN (contextual)
|
||||
sink = graph_rewrite(sink, sym+just_reduce, ctx=[0])
|
||||
|
||||
# expand
|
||||
sink = graph_rewrite(sink, sym+expander)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user