late merging of where and load (#12694)

This commit is contained in:
Sieds Lykles
2025-10-15 13:33:06 +02:00
committed by GitHub
parent 768dc952de
commit 91ac4f1f92
3 changed files with 11 additions and 9 deletions

View File

@@ -9,7 +9,7 @@ from tinygrad.renderer import Renderer
# import all pattern matchers here
from tinygrad.codegen.quantize import pm_quant
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
from tinygrad.uop.decompositions import get_late_rewrite_patterns
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
@@ -62,7 +62,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
# ** expander (expand_rewrite) **
ret.append(RewriteStep(sym+migrate_indexing, name="postopt symbolic"))
ret.append(RewriteStep(sym+migrate_indexing+pm_move_where_on_load, name="postopt symbolic"))
# expand
ret.append(RewriteStep(sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander"))

View File

@@ -1,5 +1,5 @@
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType
from tinygrad.uop.symbolic import symbolic_flat, sym, invalid_pat
from tinygrad.uop.symbolic import symbolic_flat, sym
from tinygrad.helpers import partition
from tinygrad.dtype import dtypes
@@ -100,9 +100,8 @@ pm_reduce_collapse = PatternMatcher([
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
# reduce on gated load becomes can substitute the range and remove the reduce
(UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)).load()
.reduce(arg=Ops.ADD, allow_any_len=True), lambda buf,r,idx,expr,i:
buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))).load()),
((UPat.var("idx")!=(UPat(Ops.RANGE, name="r").or_casted())).where(0, UPat.var("expr")).reduce(UPat.var("r"), arg=Ops.ADD),
lambda r,idx,expr: (v:=(idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0])).where(expr.substitute({r:idx.cast(r.dtype).valid(v)}),0)),
# AND on WHERE
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),

View File

@@ -473,6 +473,7 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
if not (dropped_clauses:=[c for c in cond.split_uop(Ops.AND) if not any(r in x.ranges for r in c.ranges)]): return None
return functools.reduce(operator.and_, [c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses], UOp.const(dtypes.bool, True)).where(x, i)
pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)])
def where_on_load(l, c1, buf, x):
c2 = x.get_valid()
duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)]
@@ -484,6 +485,11 @@ def where_on_load(l, c1, buf, x):
# aditionally we can drop the clause on the where if it already exists in the load
remaining_clause = functools.reduce(operator.and_, [c for c in c1.split_uop(Ops.AND) if c not in removed], UOp.const(dtypes.bool, True))
return remaining_clause.where(UOp.load(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2)), *l.src[1:])), 0)
pm_move_where_on_load = PatternMatcher([
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load),
(UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")),
lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)),
])
pm_simplify_valid = PatternMatcher([
# simplify valid
@@ -529,9 +535,6 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
# # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load),
(UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")),
lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)),
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
(UPat(Ops.BARRIER, name="root"),
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)