mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
late merging of where and load (#12694)
This commit is contained in:
@@ -9,7 +9,7 @@ from tinygrad.renderer import Renderer
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.quantize import pm_quant
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load
|
||||
from tinygrad.uop.decompositions import get_late_rewrite_patterns
|
||||
from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
|
||||
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
|
||||
@@ -62,7 +62,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
|
||||
ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))
|
||||
|
||||
# ** expander (expand_rewrite) **
|
||||
ret.append(RewriteStep(sym+migrate_indexing, name="postopt symbolic"))
|
||||
ret.append(RewriteStep(sym+migrate_indexing+pm_move_where_on_load, name="postopt symbolic"))
|
||||
|
||||
# expand
|
||||
ret.append(RewriteStep(sym+pm_pre_expander+pm_group_for_reduce+expander, name="expander"))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType
|
||||
from tinygrad.uop.symbolic import symbolic_flat, sym, invalid_pat
|
||||
from tinygrad.uop.symbolic import symbolic_flat, sym
|
||||
from tinygrad.helpers import partition
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
@@ -100,9 +100,8 @@ pm_reduce_collapse = PatternMatcher([
|
||||
((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
|
||||
lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
|
||||
# reduce on gated load becomes can substitute the range and remove the reduce
|
||||
(UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)).load()
|
||||
.reduce(arg=Ops.ADD, allow_any_len=True), lambda buf,r,idx,expr,i:
|
||||
buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))).load()),
|
||||
((UPat.var("idx")!=(UPat(Ops.RANGE, name="r").or_casted())).where(0, UPat.var("expr")).reduce(UPat.var("r"), arg=Ops.ADD),
|
||||
lambda r,idx,expr: (v:=(idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0])).where(expr.substitute({r:idx.cast(r.dtype).valid(v)}),0)),
|
||||
# AND on WHERE
|
||||
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
|
||||
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
|
||||
@@ -473,6 +473,7 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None:
|
||||
if not (dropped_clauses:=[c for c in cond.split_uop(Ops.AND) if not any(r in x.ranges for r in c.ranges)]): return None
|
||||
return functools.reduce(operator.and_, [c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses], UOp.const(dtypes.bool, True)).where(x, i)
|
||||
pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)])
|
||||
|
||||
def where_on_load(l, c1, buf, x):
|
||||
c2 = x.get_valid()
|
||||
duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)]
|
||||
@@ -484,6 +485,11 @@ def where_on_load(l, c1, buf, x):
|
||||
# aditionally we can drop the clause on the where if it already exists in the load
|
||||
remaining_clause = functools.reduce(operator.and_, [c for c in c1.split_uop(Ops.AND) if c not in removed], UOp.const(dtypes.bool, True))
|
||||
return remaining_clause.where(UOp.load(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2)), *l.src[1:])), 0)
|
||||
pm_move_where_on_load = PatternMatcher([
|
||||
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load),
|
||||
(UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")),
|
||||
lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)),
|
||||
])
|
||||
|
||||
pm_simplify_valid = PatternMatcher([
|
||||
# simplify valid
|
||||
@@ -529,9 +535,6 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"),
|
||||
lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0
|
||||
# # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer
|
||||
(UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load),
|
||||
(UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")),
|
||||
lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)),
|
||||
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
|
||||
(UPat(Ops.BARRIER, name="root"),
|
||||
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
|
||||
|
||||
Reference in New Issue
Block a user