diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index a1217f3ac4..4845bc32ea 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -4,9 +4,8 @@ from tinygrad import dtypes, Device from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo from tinygrad.ops import UPat, PatternMatcher -from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index -from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, load_store_indexing, sym, float4_folding, migrate_indexing +from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, reducer, sym, float4_folding, finalize, migrate_indexing from tinygrad.codegen.linearize import linearize_uop from tinygrad.shape.shapetracker import ShapeTracker, View @@ -446,10 +445,12 @@ class TestUOpGraph(unittest.TestCase): def expander_rewrite(sink): sink = graph_rewrite(sink, sym + expander) - return graph_rewrite(sink, sym + load_store_indexing) + sink = graph_rewrite(sink, sym + reducer) + return graph_rewrite(sink, sym + finalize) def float4_rewrite(sink): sink = graph_rewrite(sink, sym + migrate_indexing) - return graph_rewrite(sink, sym + expander + float4_folding) + sink = graph_rewrite(sink, sym + expander + float4_folding) + return graph_rewrite(sink, sym + finalize) class TestExpander(unittest.TestCase): def test_expand_add_broadcast(self): @@ -617,11 +618,11 @@ class TestLoadStoreFolder(unittest.TestCase): buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp(UOps.DEFINE_VAR, dtypes.bool) load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)] - sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)).sink() - sink = full_graph_rewrite(sink, Renderer()) + sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) + sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1 single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0] - self.assertEqual(single_load.src[1].op, UOps.VECTORIZE) + self.assertEqual(single_load.src[1].op, UOps.CONST) def test_simple_load_dont_fold_different_gated(self): buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) @@ -636,7 +637,7 @@ class TestLoadStoreFolder(unittest.TestCase): buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) - sink = full_graph_rewrite(sink, Renderer()) + sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1 def test_simple_store_fold_gate(self): @@ -644,7 +645,7 @@ class TestLoadStoreFolder(unittest.TestCase): gate = UOp.variable("g1", False, True, dtypes.bool) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) - sink = full_graph_rewrite(sink, Renderer()) + sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1 one_store = [x for x in sink.sparents if x.op is UOps.STORE][0] assert len(one_store.src) == 3 @@ -656,7 +657,8 @@ class TestLoadStoreFolder(unittest.TestCase): gate2 = UOp.variable("g2", False, True, dtypes.bool) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) - sink = full_graph_rewrite(sink, Renderer()) + sink = float4_rewrite(sink) + print(sink) assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3 class TestIFUOps(unittest.TestCase): diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 196bd3a285..dc134c62db 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -127,15 +127,10 @@ transcendental_patterns = [ (UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=UnaryOps.SIN), xsin), ] -@functools.lru_cache(None) -def get_transcendental_patterns(ops, force_transcendental=False): - pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental] - return PatternMatcher(pat) - powers_of_two = {2**i:i for i in range(64)} @functools.lru_cache(None) -def get_extra_patterns(ops): - pat: List[Tuple[UPat, Callable]] = [] +def get_extra_patterns(ops, force_transcendental=False): + pat = [(p[0], cast(Callable, p[1])) for p in transcendental_patterns if p[0].arg not in ops or force_transcendental] # rewrite MOD to AND (which should always be supported, but not for generic in tests) if BinaryOps.AND in ops: pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))), @@ -457,23 +452,13 @@ devectorize = PatternMatcher([ (UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store), ]) -def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, store_gate:UOp) -> Optional[UOp]: - @functools.lru_cache(None) - def find_gate(x:UOp) -> Optional[UOp]: - if x.op is UOps.IF: return x - return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None) - if (gate:=find_gate(store)) is None or gate.src[0] is not store_gate: return None - return UOp.store(buf.index(idx), *store.src[1:]) - -load_store_indexing = PatternMatcher([ +reducer = PatternMatcher([ # late fixup of unfoldable image loads (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), # simplify valid (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), # image load valid idx simplification (UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), - # delete_redundant_gates (after expand) - (UPat(UOps.STORE, src=(UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")),), allow_any_len=True, name="store"), delete_redundant_gates), ]) def idx_load_store(x:UOp): @@ -496,39 +481,55 @@ def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx) return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:]) -pm_render = PatternMatcher([ - # renderers can't deal with VCONST or multiGEP - (UPat(UOps.CONST, name='c'), - lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None), - (UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))), - (UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), - # basic sym rule, don't vectorize size one - (UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), +def delete_redundant_gates(root:UOp) -> Optional[UOp]: + @functools.lru_cache(None) + def find_gate(x:UOp) -> Optional[UOp]: + if x.op is UOps.IF: return x + return next((ret for s in x.src if (ret:=find_gate(s)) is not None), None) + if len(root.src) == 2 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[2]: return None + return UOp(UOps.STORE, root.dtype, root.src[:2], root.arg) + +finalize = PatternMatcher([ # move masks of loads/stores # TODO: this should be an IF instead of a masked STORE (UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))), masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask), + # delete_redundant_gates (after expand) + (UPat(UOps.STORE, name="root"), delete_redundant_gates), +]) + +# for rendering, we don't use vector +pm_render = PatternMatcher([ + (UPat(UOps.CONST, name='c'), + lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None), + (UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))), + (UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), + (UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), ]) # *** uop graph *** def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}" - supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else () - extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([]) - # initial symbolic + migrate indexing (remove this) + early transcendental - sink = graph_rewrite(sink, sym+migrate_indexing+get_transcendental_patterns(supported_ops, TRANSCENDENTAL>=2)) + # temp for indexing migration + sink = graph_rewrite(sink, sym+migrate_indexing) - # convert EXPAND -> VECTORIZE + # expand sink = graph_rewrite(sink, sym+expander) - # convert REDUCE to DEFINE_ACC + ASSIGN (contextual, belongs in lowerer) + # convert REDUCE to DEFINE_ACC + ASSIGN (contextual) sink = graph_rewrite(sink, sym+just_reduce, ctx=[0]) - # devectorize + load/store indexing - sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing) + # devectorize + sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)) - # final rules for the renderer (without sym) - sink = graph_rewrite(sink, pm_render+get_extra_patterns(supported_ops)+extra_matcher) + # cleanups + sink = graph_rewrite(sink, sym+reducer) + + # finalize + sink = graph_rewrite(sink, sym+finalize+get_extra_patterns(tuple(opts.code_for_op.keys()) if opts is not None else (), TRANSCENDENTAL>=2)) + + # for rendering without sym (including the rules from the renderer) + sink = graph_rewrite(sink, (pm_render+opts.extra_matcher if opts is not None and opts.extra_matcher is not None else pm_render)) return sink diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 9f7fe9998a..1f00e44391 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -47,7 +47,7 @@ ptx_matcher = symbolic+PatternMatcher([ (UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True), lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])), # load/store use pointer arithmetic, and the cast does nothing - (UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize), + (UPat(UOps.INDEX, name="x"), lambda x: x.src[0].cast(dtypes.int64) + x.src[1].cast(dtypes.int64)*x.src[0].dtype.itemsize), (UPat(UOps.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None), ])