mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
reduce with a mul chain (#9799)
* reduce with a mul chain * inside is just 1
This commit is contained in:
@@ -417,6 +417,17 @@ index_load = UPat.var("buf").index(rng_aug).load(name="ld")
|
||||
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
|
||||
arange_m = (arange_augrng<UPat.cvar("compval")).where(UPat.const(None, 0), UPat.cvar("multconst"))
|
||||
|
||||
def reduce_mul_chain(r:UOp):
|
||||
if r.arg not in {Ops.ADD, Ops.MAX}: return None
|
||||
if r.dtype != r.src[0].dtype: return None
|
||||
inside, outside = [], []
|
||||
for m in split_uop(r.src[0], Ops.MUL):
|
||||
m_parents = m.toposort
|
||||
if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m)
|
||||
else: inside.append(m)
|
||||
if len(outside) == 0: return None
|
||||
return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside)
|
||||
|
||||
# this is symbolic 2.0
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
# self ASSIGN is just self
|
||||
@@ -487,7 +498,9 @@ sym = symbolic_flat+PatternMatcher([
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
|
||||
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
|
||||
# move const multiply after REDUCE
|
||||
# move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce)
|
||||
(UPat(Ops.REDUCE, src=(UPat.var("x")*UPat.cvar("c", vec=False),), arg=Ops.ADD, name="r", allow_any_len=True),
|
||||
lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
|
||||
# reduce mul chain, move muls after the reduce
|
||||
(UPat(Ops.REDUCE, src=(UPat(Ops.MUL),), name="r", allow_any_len=True), reduce_mul_chain),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user