better where on load folding (#12651)

* move where clauses to load

* shorten line

* drop clauses if they are duplicated

* add rule for swapped where branch

* where on ungated load

* dont move clause if load is in the clause

* parse_valid returns None

* no data dependent branches

* fix rule

* enable swapped rule

* remove those
This commit is contained in:
Sieds Lykles
2025-10-14 13:30:47 +02:00
committed by GitHub
parent c7e63601fd
commit 852d80dff9
3 changed files with 19 additions and 11 deletions

View File

@@ -99,14 +99,10 @@ pm_reduce_collapse = PatternMatcher([
# MUL casted bool
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
# WHERE on LOAD (works on max too)
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx.valid(gate)).load()),
(UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx.valid(gate.logical_not())).load()),
# INDEX on RANGE / gated RANGE
(UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)),
lambda buf,r,idx,expr,i: buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0])))),
# reduce on gated load: we can substitute the range and remove the reduce
(UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)).load()
.reduce(arg=Ops.ADD, allow_any_len=True), lambda buf,r,idx,expr,i:
buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))).load()),
# AND on WHERE
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),

View File

@@ -344,7 +344,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
# build a CONST UOp of index dtype (widened with `.vec(count)`) that carries the Invalid sentinel as its arg
@staticmethod
def invalid(count=1): return UOp(Ops.CONST, dtypes.index.vec(count), src=(), arg=Invalid)
# gate `self` with `cond`: a constant-true condition gates nothing, so `self` is returned unchanged;
# otherwise build `cond.where(self, Invalid)` with an Invalid constant of matching vector count
def valid(self, cond): return self if cond.op is Ops.CONST and cond.arg else cond.where(self, UOp.invalid(self.dtype.count))
def get_idx(self) -> UOp:
  """Return the index with its validity gate stripped, if it has one.

  A gated index is a WHERE whose third source carries the Invalid arg; in that case the second source is returned.
  Any other index-dtype UOp is returned unchanged.
  """
  assert self.dtype.scalar() is dtypes.index, "Can only call get_idx on index dtype"
  is_gated = self.op is Ops.WHERE and self.src[2].arg is Invalid
  if is_gated: return self.src[1]
  return self

View File

@@ -473,6 +473,17 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
if not (dropped_clauses:=[c for c in cond.split_uop(Ops.AND) if not any(r in x.ranges for r in c.ranges)]): return None
return functools.reduce(operator.and_, [c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses], UOp.const(dtypes.bool, True)).where(x, i)
# matches `cond.where(x, <invalid_pat>)` where `x` has index dtype and hands it to drop_and_clauses
# NOTE(review): `invalid_pat` is defined elsewhere; presumably it matches the Invalid constant -- confirm
pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)])
def where_on_load(l:UOp, c1:UOp, buf:UOp, x:UOp) -> UOp|None:
  """Fold clauses of the WHERE condition `c1` into the validity gate of the index `x` used by load `l`.

  Rewrites `c1.where(load(buf.index(x)), 0)`. `c1` is split on AND and each clause is handled as follows:
    * a clause that already appears in the gate of `x` is dropped from the WHERE,
    * a clause is moved into the gate when every range it uses is also a range of `x` and it has no LOAD
      in its backward slice,
    * all other clauses stay on the WHERE.

  Returns the rewritten UOp, or None when no clause can be dropped or moved.
  """
  c2 = x.get_valid()
  # split each condition once instead of re-walking it for every membership test
  c1_clauses, c2_clauses = tuple(c1.split_uop(Ops.AND)), tuple(c2.split_uop(Ops.AND))
  duplicate_clauses = [c for c in c1_clauses if c in c2_clauses]
  # we move the condition from the where to the load _as long as_ the condition doesn't have some range that would place it inside of a new range
  # also no data dependent loads!
  moved_clauses = [c for c in c1_clauses if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges)
                   and not c.op_in_backward_slice_with_self(Ops.LOAD)]
  if not (removed:=moved_clauses+duplicate_clauses): return None
  # additionally we can drop the clause on the where if it already exists in the load
  remaining_clause = functools.reduce(operator.and_, [c for c in c1_clauses if c not in removed], UOp.const(dtypes.bool, True))
  gated_idx = x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))
  # any extra sources of the original load stay on the LOAD, not on the INDEX
  return remaining_clause.where(UOp.load(buf.index(gated_idx), *l.src[1:]), 0)
pm_simplify_valid = PatternMatcher([
# simplify valid
@@ -518,8 +529,9 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
# # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("c2").where(UPat(), invalid_pat)).or_casted(),), name="l"), 0),
lambda c1,c2,l,i: l.replace(src=(l.src[0],)+l.src[1:]) if all(c in list(c2.split_uop(Ops.AND)) for c in c1.split_uop(Ops.AND)) else None),
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load),
(UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")),
lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)),
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
(UPat(Ops.BARRIER, name="root"),
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)