mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
better where on load folding (#12651)
* move where clauses to load * shorten line * drop clauses if they are duplicated * add rule for swapped where branch * where on ungated load * dont move clause if load is in the clause * parse_valid returns None * no data dependent branches * fix rule * enable swapped rule * remove those
This commit is contained in:
@@ -99,14 +99,10 @@ pm_reduce_collapse = PatternMatcher([
|
||||
# MUL casted bool
|
||||
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
|
||||
lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
|
||||
# WHERE on LOAD (works on max too)
|
||||
(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
|
||||
lambda buf,idx,gate: buf.index(idx.valid(gate)).load()),
|
||||
(UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
|
||||
lambda buf,idx,gate: buf.index(idx.valid(gate.logical_not())).load()),
|
||||
# INDEX on RANGE / gated RANGE
|
||||
(UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)),
|
||||
lambda buf,r,idx,expr,i: buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0])))),
|
||||
# reduce on gated load becomes can substitute the range and remove the reduce
|
||||
(UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)).load()
|
||||
.reduce(arg=Ops.ADD, allow_any_len=True), lambda buf,r,idx,expr,i:
|
||||
buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))).load()),
|
||||
# AND on WHERE
|
||||
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
|
||||
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
|
||||
@@ -344,7 +344,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
|
||||
@staticmethod
|
||||
def invalid(count=1): return UOp(Ops.CONST, dtypes.index.vec(count), src=(), arg=Invalid)
|
||||
def valid(self, cond): return cond.where(self, UOp.invalid(self.dtype.count))
|
||||
def valid(self, cond): return self if cond.op is Ops.WHERE and cond.arg else cond.where(self, UOp.invalid(self.dtype.count))
|
||||
def get_idx(self) -> UOp:
|
||||
assert self.dtype.scalar() is dtypes.index, "Can only call get_idx on index dtype"
|
||||
return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self
|
||||
|
||||
@@ -473,6 +473,17 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
|
||||
if not (dropped_clauses:=[c for c in cond.split_uop(Ops.AND) if not any(r in x.ranges for r in c.ranges)]): return None
|
||||
return functools.reduce(operator.and_, [c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses], UOp.const(dtypes.bool, True)).where(x, i)
|
||||
pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)])
|
||||
def where_on_load(l, c1, buf, x):
|
||||
c2 = x.get_valid()
|
||||
duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)]
|
||||
# we move the condition from the where to the load _as long as_ the condtition doesn't have some range that would place it inside of a new range
|
||||
# also no data dependent loads!
|
||||
moved_clauses = [c for c in c1.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges)
|
||||
and not c.op_in_backward_slice_with_self(Ops.LOAD)]
|
||||
if not (removed:=moved_clauses+duplicate_clauses): return None
|
||||
# aditionally we can drop the clause on the where if it already exists in the load
|
||||
remaining_clause = functools.reduce(operator.and_, [c for c in c1.split_uop(Ops.AND) if c not in removed], UOp.const(dtypes.bool, True))
|
||||
return remaining_clause.where(UOp.load(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2)), *l.src[1:])), 0)
|
||||
|
||||
pm_simplify_valid = PatternMatcher([
|
||||
# simplify valid
|
||||
@@ -518,8 +529,9 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
|
||||
# # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
|
||||
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("c2").where(UPat(), invalid_pat)).or_casted(),), name="l"), 0),
|
||||
lambda c1,c2,l,i: l.replace(src=(l.src[0],)+l.src[1:]) if all(c in list(c2.split_uop(Ops.AND)) for c in c1.split_uop(Ops.AND)) else None),
|
||||
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load),
|
||||
(UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")),
|
||||
lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)),
|
||||
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
|
||||
(UPat(Ops.BARRIER, name="root"),
|
||||
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
|
||||
|
||||
Reference in New Issue
Block a user