no broadcasting/vectors in reduce collapse (#12729)

This commit is contained in:
Sieds Lykles
2025-10-16 13:22:57 +02:00
committed by GitHub
parent 533f18b22c
commit 8f740e07ff

View File

@@ -97,14 +97,12 @@ pm_reduce_collapse = PatternMatcher([
((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
# MUL casted bool
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)),
# reduce on gated load becomes can substitute the range and remove the reduce
((UPat.var("idx")!=(UPat(Ops.RANGE, name="r").or_casted())).where(0, UPat.var("expr")).reduce(UPat.var("r"), arg=Ops.ADD),
lambda r,idx,expr: (v:=(idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0])).where(expr.substitute({r:idx.cast(r.dtype).valid(v)}),0)),
# AND on WHERE
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
((UPat(Ops.DEFINE_VAR, name="x") & UPat.var("y")).where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
# remove REDUCEs that no longer have a RANGE in the src
(UPat(Ops.REDUCE, name="red"), reduce_rangeless),