diff --git a/test/test_uops.py b/test/test_uops.py index eb23f10b33..b29137a015 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -517,7 +517,7 @@ class TestUOpStr(unittest.TestCase): class TestUPatHelpers(unittest.TestCase): def test_location(self): - self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py") + self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "mixins.py") self.assertEqual(shared_spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py") test_upat = UPat(Ops.CONST, dtypes.bool) self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1]) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 3e95d20f16..b16cdc6ee7 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element -from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate +from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic, invalid_gate from tinygrad.helpers import getenv, flatten, AMX, prod from tinygrad.renderer import Renderer @@ -61,7 +61,7 @@ def expand_index(buf:UOp, vec:UOp): if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx() # generate the individual indexes midx = graph_rewrite(UOp.sink(*[buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count)]), - symbolic_flat+load_store_indexing, name=f"index_buf_{buf.arg}") + symbolic+load_store_indexing, name=f"index_buf_{buf.arg}") # extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) for i in range(vec.dtype.count): diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index a625dc2cf2..dfb2358654 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -1,6 +1,6 @@ import itertools from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType -from tinygrad.uop.symbolic import symbolic_flat +from tinygrad.uop.symbolic import symbolic from tinygrad.helpers import partition, dedup from tinygrad.dtype import dtypes @@ -28,7 +28,7 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None: s0, s1 = r0.src[0], r1.src[0] # do the merge new_range = r0.replace(src=(s0*s1,)) - nidx = graph_rewrite(u, _substitute+symbolic_flat+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1}, + nidx = graph_rewrite(u, _substitute+symbolic+pm_flatten_range, ctx={r0:new_range//s1, r1:new_range%s1}, name=f"check_merge_{r0.arg[0]}_{r1.arg[0]}") # check if it simplifies @@ -109,7 +109,7 @@ pm_reduce_collapse = pm_reduce_unparented + PatternMatcher([ lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)), # MUL casted bool ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast()), lambda x,gate: gate.where(x, 0)), -])+symbolic_flat +])+symbolic pm_reduce_load_collapse = pm_reduce_collapse + PatternMatcher([ # lift x+y out of reduce on ne diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 864fdf54c7..c20936200f 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -3,7 +3,7 @@ import itertools from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags -from tinygrad.uop.symbolic import symbolic_flat +from tinygrad.uop.symbolic import symbolic from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify @@ -536,7 +536,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: # convert movement ops to ranges tsink, rctx = run_rangeify(tsink, DEBUG_RANGEIFY) - tsink = graph_rewrite(tsink, symbolic_flat+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse") # this does const folding + tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding, name="symbolic+reduce_collapse") # this does const folding tsink = graph_rewrite(tsink, pm_remove_bufferize, bottom_up=True, name="remove bufferize with cost function") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 99bd7f0fd2..13a7156211 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -382,13 +382,6 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))), ])+gep_pushing -symbolic_flat = symbolic+PatternMatcher([ - # ** combine terms (opinionated) ** - (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y - # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue - ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), -]) - # ******** we take a small aside to "simplify_valid" to rewrite valids ******** def parse_valid(valid:UOp) -> tuple[UOp, bool, int]|None: @@ -503,7 +496,7 @@ pm_simplify_valid = PatternMatcher([ # this is symbolic 2.0 REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK} -sym = symbolic_flat+pm_simplify_valid+PatternMatcher([ +sym = symbolic+pm_simplify_valid+PatternMatcher([ # LOAD/STORE -> NOOP (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]), (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c), @@ -553,4 +546,8 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([ if any(x.op in REMOVE_FROM_SINK_LIKE for x in root.src) else None), # remove END with empty NOOP (UPat(Ops.END, src=(UPat(Ops.NOOP, src=(), name="noop"),), allow_any_len=True), lambda noop:noop), + # ** combine terms (opinionated) ** + (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y + # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue + ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), ])