diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index bad5eaa797..513d35d461 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -2,7 +2,7 @@ import math, operator, struct, functools from collections import defaultdict from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu -from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace, can_safe_cast, Invalid +from tinygrad.dtype import ConstType, dtypes, PtrDType, can_safe_cast, Invalid from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING, unwrap from tinygrad.uop.decompositions import xpow @@ -450,7 +450,7 @@ def simplify_valid(valid:UOp) -> UOp|None: if ret[-1] is not stmt: something_changed = True return UOp.prod(*ret) if something_changed else None -# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ******** +# ******** phase 3 is the complete symbolic ******** def reduce_mul_chain(r:UOp): if r.arg not in {Ops.ADD, Ops.MAX}: return None @@ -479,6 +479,8 @@ def where_on_load(c1, buf, x): # aditionally we can drop the clause on the where if it already exists in the load remaining_clause = UOp.const(dtypes.bool, True).prod(*[c for c in c1.split_uop(Ops.AND) if c not in removed]) return remaining_clause.where(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2))), 0) + +# where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer pm_move_where_on_load = PatternMatcher([ (UPat.var("c1").where(UPat.var("buf").index(UPat.var("x")), 0), where_on_load), (UPat.var("c1").where(0, UPat.var("buf").index(UPat.var("x"))), lambda c1,buf,x: where_on_load(c1.logical_not(),buf,x)), @@ -494,9 +496,6 @@ pm_simplify_valid = PatternMatcher([ # this is symbolic 2.0 REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK} sym = symbolic+pm_simplify_valid+PatternMatcher([ - # LOAD/STORE -> NOOP - (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]), - (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c), # VECTORIZE/GEP (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat.var("x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))), # reorder ALU/VECTORIZE @@ -525,7 +524,6 @@ sym = symbolic+pm_simplify_valid+PatternMatcher([ # fold gated LOAD/STORE (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"), lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0 - # # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer ((UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()), # 1/(x^c) -> (1/x)^c ((UPat.var("x") * UPat.var("x") * UPat.var("x")).reciprocal(), lambda x: x.reciprocal()*x.reciprocal()*x.reciprocal()), ((UPat.var("x") * UPat.cvar("c")).reciprocal(), lambda x,c: x.reciprocal()*c.reciprocal()), # 1/(x*c) -> (1/c)*(1/x)