From 8f740e07ff14770ffe13461cd0573d69b3eba179 Mon Sep 17 00:00:00 2001 From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:22:57 +0200 Subject: [PATCH] no broadcasting/vectors in reduce collapse (#12729) --- tinygrad/codegen/simplify.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index c558f4a93c..b752ffd0cf 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -97,14 +97,12 @@ pm_reduce_collapse = PatternMatcher([ ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)), # MUL casted bool - ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")), - lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)), + ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)), # reduce on gated load becomes can substitute the range and remove the reduce ((UPat.var("idx")!=(UPat(Ops.RANGE, name="r").or_casted())).where(0, UPat.var("expr")).reduce(UPat.var("r"), arg=Ops.ADD), lambda r,idx,expr: (v:=(idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0])).where(expr.substitute({r:idx.cast(r.dtype).valid(v)}),0)), # AND on WHERE - ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \ - .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), + ((UPat(Ops.DEFINE_VAR, name="x") & UPat.var("y")).where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)), # remove REDUCEs that no longer have a RANGE in the src (UPat(Ops.REDUCE, name="red"), reduce_rangeless),