reduce where that is cut from two sides (#12733)

* better rule

* correct pattern

* shorten line
This commit is contained in:
Sieds Lykles
2025-10-16 15:25:15 +02:00
committed by GitHub
parent cf9baeea61
commit 55db1b0e0e

View File

@@ -91,6 +91,8 @@ pm_reduce_collapse = PatternMatcher([
# fold the range
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(UPat.var("r"), arg=Ops.ADD),
lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
(((UPat.var("r")<UPat.var("lower")).logical_not()&(UPat(Ops.RANGE, name="r")<UPat.var("upper"))).where(UPat.cvar("val"), 0).reduce(UPat.var("r"),
arg=Ops.ADD), lambda r,lower,upper,val: (upper.minimum(r.src[0])-lower.maximum(0)).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(UPat.var("r"), arg=Ops.ADD),
lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
# REDUCE on ADD