reduce with a mul chain (#9799)

* reduce with a mul chain

* inside is just 1
This commit is contained in:
George Hotz
2025-04-09 12:42:32 +08:00
committed by GitHub
parent 78caf55154
commit bb18adb0d5

View File

@@ -417,6 +417,17 @@ index_load = UPat.var("buf").index(rng_aug).load(name="ld")
# alternatives accepted for the "augmented range" expression: rng_aug on its own, rng_aug plus one extra
# addend (captured as "idx2"), rng_aug plus two extra addends ("idx2" and "idx3"), or an Ops.VECTORIZE
# (captured as "vec") built with src=rng_aug
# NOTE(review): rng_aug is defined outside this view -- confirm what it matches before relying on this summary
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
# (arange_augrng < compval).where(0, multconst): a compare of the augmented range against the captured "compval",
# feeding a where() whose operands are a const 0 (built with None as its first argument) and the captured "multconst"
# NOTE(review): presumably cvar only matches constants and where(a, b) picks a when the compare is true -- confirm
arange_m = (arange_augrng<UPat.cvar("compval")).where(UPat.const(None, 0), UPat.cvar("multconst"))
def reduce_mul_chain(r:UOp):
  """Move MUL factors that do not depend on the reduce sources out of an ADD/MAX REDUCE.

  r.src[0] is split into its MUL factors with split_uop. A factor is moved outside the
  reduce when none of r.src[1:] appears in the factor's toposort and, for a MAX reduce,
  the factor's vmin is >= 0. All other factors stay inside.

  Returns None (no rewrite) when r.arg is not ADD or MAX, when the reduce dtype differs
  from the dtype of r.src[0], or when no factor can be moved. Otherwise returns the
  REDUCE rebuilt over the product of the remaining factors (a const_like(1) of r.src[0]
  if none remain) multiplied by the product of the moved factors.
  """
  # only same-dtype ADD/MAX reduces are handled
  if r.arg not in {Ops.ADD, Ops.MAX} or r.dtype != r.src[0].dtype: return None
  reduce_srcs = r.src[1:]
  inside, outside = [], []
  for factor in split_uop(r.src[0], Ops.MUL):
    parents = factor.toposort
    depends_on_reduce = any(s in parents for s in reduce_srcs)
    # max(c*x) == c*max(x) only holds for c >= 0, so MAX additionally requires a non-negative factor
    hoistable = not depends_on_reduce and (r.arg != Ops.MAX or factor.vmin >= 0)
    (outside if hoistable else inside).append(factor)
  if not outside: return None
  # if every factor was moved out, what remains inside the reduce is just 1
  kept = prod(inside) if inside else r.src[0].const_like(1)
  return r.replace(src=(kept,)+reduce_srcs)*prod(outside)
# this is symbolic 2.0
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
@@ -487,7 +498,9 @@ sym = symbolic_flat+PatternMatcher([
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")), lambda x,d: 1-d), # x*/(1+x) -> 1-1/(1+x)
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")*UPat.var("y")), lambda x,y,d: y*(1-d)),
(UPat.var("x") * ((1+UPat.var("x")).reciprocal().named("d")+UPat.var("y")), lambda x,y,d: (1-d)+x*y),
# move const multiply after REDUCE
# move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce)
(UPat(Ops.REDUCE, src=(UPat.var("x")*UPat.cvar("c", vec=False),), arg=Ops.ADD, name="r", allow_any_len=True),
lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
# reduce mul chain, move muls after the reduce
(UPat(Ops.REDUCE, src=(UPat(Ops.MUL),), name="r", allow_any_len=True), reduce_mul_chain),
])