From 17a2f744129f02372155c8aad25a2962b22fb4cb Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 30 Jul 2024 13:17:20 -0700
Subject: [PATCH] new style load/store folder (#5784)

* remove old index reorder
* new style folder
* works better
* dedup
* one failure
* this is fine now...
* expander_rewrite
* images broken, but all else should work
* cleanups
* make tests work with old
* fix images
* cleanups + bugfix
* minor fixes
* fix gated store folding
* flip gate_creator and expander
* fix gated store
* remove unneeded rules
* lines getting close
* line count good
---
 test/test_linearizer.py      |   7 +-
 test/test_uop_graph.py       |  86 ++++++++++++++---
 tinygrad/codegen/lowerer.py  |   9 +-
 tinygrad/codegen/uopgraph.py | 181 +++++++++++++++++------------------
 4 files changed, 172 insertions(+), 111 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 814274b73d..249a2fed8f 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -607,7 +607,7 @@ class TestLinearizer(unittest.TestCase):
     assert accs[0].dtype == stores[0].src[-1].dtype == dtypes.float.vec(4)
     assert stores[0].src[0].op is UOps.DEFINE_LOCAL
     # the second store is to gds with no upcasts
-    assert accs[1].dtype == stores[1].src[2].dtype == dtypes.float
+    assert stores[1].src[2].dtype == dtypes.float
     assert stores[1].src[0].op is UOps.DEFINE_GLOBAL
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@@ -954,7 +954,7 @@ class TestLinearizer(unittest.TestCase):
     barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
     # check that the float4 cast collapses for all stores
     for store in local_stores+global_stores:
-      assert store.src[2].dtype == dtypes.float.vec(2) and store.src[2].op is not UOps.VECTORIZE
+      assert store.src[2].dtype.count > 1 and store.src[2].op is not UOps.VECTORIZE
     # # check the children's vins
     # TODO: src ALU are not the same, should it?
     # assert barrier.src == tuple(local_stores)
@@ -1091,6 +1091,7 @@ class TestFloat4(unittest.TestCase):
     # the first conv dot product is aligned in a. If we upcast the output and reduce
     # dimension, then we could do float4 for only that one set of loads, but we currently
     # don't.
+    # UPDATE: now we do this fusion
     s = create_schedule([c.lazydata])[0]
     k = Kernel(s.ast)
 
@@ -1098,7 +1099,7 @@ class TestFloat4(unittest.TestCase):
     k.upcast()
     k.linearize()
 
-    assert TestFloat4.count_float4(k) == (0, 1)
+    assert TestFloat4.count_float4(k) in {(0,1), (1,1)}
 
   def test_float4_noncontiguous(self):
     a = Tensor.rand(4, 2).realize()
diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index 8c929f6cca..ecb16fc460 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -5,7 +5,7 @@ from tinygrad.dtype import PtrDType
 from tinygrad.helpers import DEBUG
 from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
 from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
-from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, constant_folder
+from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding
 
 simple_pm = PatternMatcher([
   (NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
@@ -82,7 +82,7 @@ class TestGraphRewrite(unittest.TestCase):
     b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('b', 0, 1))
     c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('c', 0, 1))
     d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('d', 0, 1))
-    outs = [2+a, 2+a+d+3+b+c+4] #, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
+    outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
     for out in outs:
       sink = graph_rewrite(out, constant_folder)
       print(sink)
@@ -233,12 +233,12 @@ class TestUOpGraph(TestUOps):
     idx = UOp.const(dtypes.int, 0)
     ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2)))
     ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3)))
-    uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld0+ld1))])
+    uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
     ld0, ld1 = uops[-1].src[2].src
-    # ld0 becomes the invalid value
-    self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2))
-    # the gate and invalid value are deleted from ld1
-    self.assert_equiv_uops(ld1, UOp.load(glbl2, idx, dtype=dtypes.int))
+    # the original ld0 (gate is False) folds to its invalid value
+    self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
+    # the original ld1 (gate is True) loses its gate and invalid value
+    self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
 
   def test_fold_gated_load_local(self):
     glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
@@ -248,12 +248,12 @@ class TestUOpGraph(TestUOps):
     barrier = UOp(UOps.BARRIER, None, (st, ))
     ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier))
     ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.bool, True), UOp.const(dtypes.int, 3), barrier))
-    uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld0+ld1))])
+    uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
    ld0, ld1 = uops[-1].src[2].src
-    # ld0 becomes the invalid value
-    self.assert_equiv_uops(ld0, UOp.const(dtypes.int, 2))
-    # the gate and invalid value are deleted from ld1
-    self.assert_equiv_uops(ld1, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
+    # the original ld0 (gate is False) folds to its invalid value
+    self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
+    # the original ld1 (gate is True) loses its gate and invalid value
+    self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
 
   def test_fold_gated_store(self):
     glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
@@ -289,12 +289,8 @@ class TestUOpGraph(TestUOps):
     # ranges are closed in the right order
     self.assertEqual(endranges[-1].src[0], ranges[0])
 
-def expander_rewrite(sink):
-  together = PatternMatcher(expander.patterns + constant_folder.patterns)
-  return graph_rewrite(sink, together)
-  #out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
-  #out.linearize()
-  #return out.uops[-1]
+def expander_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer)
+def float4_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + float4_folding)
 
 class TestExpander(unittest.TestCase):
   def test_expand_add_broadcast(self):
@@ -424,5 +420,67 @@ class TestExpander(unittest.TestCase):
     sink = expander_rewrite(sink)
     print(sink)
 
+class TestLoadStoreFolder(unittest.TestCase):
+  def test_simple_load_fold(self):
+    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
+    load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
+    sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
+    sink = float4_rewrite(sink)
+    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
+
+  def test_two_load_fold(self):
+    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
+    load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
+    sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,8),))
+    sink = float4_rewrite(sink)
+    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
+
+  def test_simple_load_fold_gated(self):
+    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
+    gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
+    load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate, UOp.const(dtypes.float, i))) for i in range(4)]
+    sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
+    sink = float4_rewrite(sink)
+    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
+    single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
+    self.assertListEqual([src.arg for src in single_load.src[3].src], [0.0, 1.0, 2.0, 3.0])
+
+  def test_simple_load_dont_fold_different_gated(self):
+    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
+    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
+    gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
+    load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), gate if i == 0 else gate2, UOp.const(dtypes.float, i))) for i in range(4)]
+    sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
+    sink = float4_rewrite(sink)
+    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
+
+  def test_simple_store_fold(self):
+    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
+    load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)]
+    sink = UOp(UOps.SINK, None, tuple(load))
+    sink = float4_rewrite(sink)
+    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
+
+  def test_simple_store_fold_gate(self):
+    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
+    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
+    load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
+    sink = UOp(UOps.SINK, None, tuple(load))
+    sink = float4_rewrite(sink)
+    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
+    one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
+    assert len(one_store.src) == 4
+    assert str(one_store.src[3]) == str(gate)  # huh, why do i need str here?
+
+  def test_simple_store_dont_fold(self):
+    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
+    gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
+    gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
+    load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
+    sink = UOp(UOps.SINK, None, tuple(load))
+    sink = float4_rewrite(sink)
+    print(sink)
+    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 8ca8680ab7..a6420c7cbd 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -157,7 +157,14 @@ class IndependentLowerer:
         (x.arg.idx, x.arg.idx < self.output_count))
     if x.op is BufferOps.LOAD:
       barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
-      return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
+      load_dtype = x.arg.dtype.scalar()
+      if idx.dtype == dtypes.int.vec(3):
+        # this should all simplify if there's consts for id4. if not, w/e
+        idx, id4 = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx.src[0], idx.src[1])), idx.src[2]
+        vec_load = UOp(UOps.LOAD, load_dtype.vec(4), (buf, idx) + ((valid, UOp.const(load_dtype.vec(4), 0)) if has_valid else ()) + barrier)
+        return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load_dtype, (vec_load,), i)),
+                                range(4), UOp.const(load_dtype, float('nan')))
+      return UOp(UOps.LOAD, load_dtype, (buf, idx) + ((valid, UOp.const(load_dtype, 0)) if has_valid else ()) + barrier)
     # NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
     if x.arg.idx >= 0:
       for oidx, ridx in zip(self.idxs, self.ridxs):
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 79e57f5fa8..53771be9c2 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -1,80 +1,79 @@
 from __future__ import annotations
-from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING
+from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict
 import functools, itertools, heapq, math
-from tinygrad.dtype import dtypes, PtrDType, ImageDType
+from collections import defaultdict
+from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
 from tinygrad.shape.symbolic import Variable
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, exec_alu
-from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI
+from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI, all_same
 from tinygrad.codegen.uops import UOp, NOp, UOps, UPat, PatternMatcher, END_FOR_UOP, type_verify
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
 if TYPE_CHECKING: from tinygrad.renderer import Renderer
 
-# ***** image handling *****
+# ***** float4/image store handling *****
 
-def image_contract_load(buf, idx, idy, id4, ls):
-  # TODO: there's no contract on the gate, is this okay?
-  if len(ls.src) > 3: extra = (ls.src[2], UOp(UOps.VECTORIZE, ls.dtype.vec(4), (ls.src[3],)*4))
-  else: extra = ls.src[2:]  # NOTE: image load shouldn't have barrier and this shouldn't matter
-  vec_load = UOp(UOps.LOAD, ls.dtype.vec(4), (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy))) + extra)
-  return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, ls.dtype, (vec_load,), i)), range(4), ls.const(float('nan')))
-
-def image_contract_store(buf, ex, idx, idy, ls, var):
-  new_var = UOp(UOps.CONTRACT, var.dtype.vec(4), (var,), ((ex.arg[0][0],4),))
-  return UOp(UOps.STORE, None, (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy)), new_var) + ls.src[3:])
-
-# ***** float4 handling *****
-
-def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None, idx3=None):
-  if len(ex.src) not in [2, 4]: return None
-  if tuple(x.arg for x in ex.src if x.op is UOps.CONST) != tuple(range(len(ex.src))): return None
+def fold_expanded(ex, buf):
   if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
-  if idx2 is not None: idx = idx + idx2
-  if idx3 is not None: idx = idx + idx3
-  if idx.divides(len(ex.src)) is None: return None
+  new_srcs = dedup(list(ex.src))
+  old_new_srcs = new_srcs[:]
+  is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
 
-  if load.dtype.scalar() != load.dtype: return None  # how does this happen?
-  vec_load = UOp(UOps.LOAD, load.dtype.vec(len(ex.src)), (buf, idx))
-  return UOp(UOps.EXPAND, load.dtype, tuple(UOp(UOps.GEP, load.dtype, (vec_load,), i) for i in range(len(ex.src))), ex.arg)
+  # first, extract all the relevant offsets
+  offsets_rootsrc: DefaultDict[Any, dict] = defaultdict(dict)
+  for i,s in enumerate(new_srcs):
+    if (s.dtype is not None and s.dtype.count != 1) or (is_image and s.src[1].dtype != dtypes.int.vec(3)): continue
+    idx = s.src[1] if not is_image else s.src[1].src[2]  # only id4 for image
+    if idx.arg is BinaryOps.ADD and idx.src[1].op is UOps.CONST: root_src, arg = idx.src[0], idx.src[1].arg
+    elif idx.op is UOps.CONST: root_src, arg = "CONST", idx.arg
+    else: root_src, arg = idx, 0
+    # add idx and idy for image
+    if is_image: root_src = (s.src[1].src[0:2], root_src)
+    # add gates for gated
+    if len(s.src) >= 4: root_src = (s.src[2] if is_load else s.src[3], root_src)  # maybe flip the gate and the const?
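+    # the fully nested key is (gate, ((idx, idy), base)): only accesses that agree on the
+    # base index expression, the image x/y, and the gate can fold together; the inner dict
+    # maps each constant offset under that key to the position of its access in new_srcs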
+    assert arg not in offsets_rootsrc[root_src]
+    offsets_rootsrc[root_src][arg] = i
 
-def float4_contract_store(buf, ex, var, store, idx=UOp.const(dtypes.int, 0), idx2=None, idx3=None):
-  if len(ex.src) not in [2, 4]: return None
-  if tuple(x.arg for x in ex.src if x.op is UOps.CONST) != tuple(range(len(ex.src))): return None
-  if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
-  if idx2 is not None: idx = idx + idx2
-  if idx3 is not None: idx = idx + idx3
-  if idx.divides(len(ex.src)) is None: return None
+  # then rewrite everything we can
+  used = set()
+  for rootsrc, offsets in offsets_rootsrc.items():
+    for o in offsets:
+      for fold_length in [4] if is_image else [4, 2]:
+        if all((rootsrc,o+i) not in used and o+i in offsets for i in range(fold_length)):
+          load_1 = new_srcs[offsets[o]]
+          new_src = list(load_1.src)
+          if not is_image and not new_src[1].divides(fold_length): continue
+          # for images, we rewrite the index
+          if is_image: new_src[1] = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (new_src[1].src[0], new_src[1].src[1]))
+          if is_load:
+            # vectorize the const. if we flip const and gate it's nicer here too
+            if len(new_src) >= 4:
+              new_src[3] = UOp(UOps.VECTORIZE, load_1.dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[3] for i in range(fold_length)))
+            new_load = UOp(load_1.op, load_1.dtype.vec(fold_length), tuple(new_src), load_1.arg)
+            for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(UOps.GEP, load_1.dtype, (new_load,), i)
+          else:
+            new_src[2] = UOp(UOps.VECTORIZE, new_src[2].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[2] for i in range(fold_length)))
+            for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(load_1.op, None, tuple(new_src), load_1.arg) if i == 0 else None
+          for i in range(fold_length): used.add((rootsrc,o+i))
 
-  new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), ((ex.arg[0][0],len(ex.src)),))
-  return UOp(UOps.STORE, None, (buf, idx, new_var) + store.src[3:])
+  # dedup expand for LOAD
+  if is_load and len(old_new_srcs) != len(ex.src): new_srcs = [new_srcs[old_new_srcs.index(s)] for s in ex.src]
+  # remove Nones for STORE
+  return UOp(ex.op, ex.dtype, tuple(x for x in new_srcs if x is not None), ex.arg) if len(used) else None
+
+def vectorize_reduce(vec:UOp):
+  if not all_same([(x.src[1:], x.arg) for x in vec.src]): return None
+  return UOp(UOps.REDUCE, vec.dtype, (UOp(UOps.VECTORIZE, vec.dtype, tuple(x.src[0] for x in vec.src)),) + vec.src[0].src[1:], vec.src[0].arg)
+
+def vectorize_alu(vec:UOp):
+  if not all_same([x.arg for x in vec.src]): return None
+  return UOp(vec.src[0].op, vec.dtype, tuple(UOp(UOps.VECTORIZE, cast(DType, vec.src[0].src[i].dtype).vec(cast(DType, vec.dtype).count),
+                                                 tuple(x.src[i] for x in vec.src)) for i in range(len(vec.src[0].src))), vec.src[0].arg)
 
 float4_folding = PatternMatcher([
-  # reorder index to bring const closer to store
-  (NOp(UOps.STORE, src=(NOp.var("buf"), NOp.var("idx")+
-    (NOp(UOps.EXPAND, src=tuple(NOp.const(dtypes.int, i) for i in range(4)), name="ex")+NOp.var("idx2")), NOp.var("var")), name="store"),
-    lambda buf, store, idx, idx2, ex, var: UOp(UOps.STORE, store.dtype, (buf, idx+idx2+ex, var), store.arg)),
-  # float(2,4) load
-  (NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.EXPAND, name="ex")+NOp.var("idx")+NOp.var("idx2")+NOp.var("idx3")), name="load"), float4_expand_load),
-  (NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.EXPAND, name="ex")+NOp.var("idx")+NOp.var("idx2")), name="load"), float4_expand_load),
-  (NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.EXPAND, name="ex")+NOp.var("idx")), name="load"), float4_expand_load),
-  (NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.EXPAND, name="ex")), name="load"), float4_expand_load),
-  # float(2,4) store
-  # TODO: fold ADDs into one UOp and remove add chains
-  (NOp(UOps.STORE, src=(NOp.var("buf"),
-    NOp(UOps.EXPAND, name="ex")+NOp.var("idx")+NOp.var("idx2")+NOp.var("idx3"), NOp.var("var")), name="store", allow_any_len=True),
-    float4_contract_store),
-  (NOp(UOps.STORE, src=(NOp.var("buf"),
-    NOp(UOps.EXPAND, name="ex")+NOp.var("idx")+NOp.var("idx2"), NOp.var("var")), name="store", allow_any_len=True),
-    float4_contract_store),
-  (NOp(UOps.STORE, src=(NOp.var("buf"),
-    NOp(UOps.EXPAND, name="ex")+NOp.var("idx"), NOp.var("var")), name="store", allow_any_len=True), float4_contract_store),
-  (NOp(UOps.STORE, src=(NOp.var("buf"),
-    NOp(UOps.EXPAND, name="ex"), NOp.var("var")), name="store", allow_any_len=True), float4_contract_store),
-  # image handling
-  (NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.VECTORIZE, dtypes.int.vec(3), (NOp.var('idx'), NOp.var('idy'),
-    NOp.var('id4')))), name="ls", allow_any_len=True), image_contract_load),
-  (NOp(UOps.STORE, src=(NOp.var("buf"), NOp(UOps.VECTORIZE, dtypes.int.vec(3), (NOp.var('idx'), NOp.var('idy'),
-    NOp(UOps.EXPAND, src=tuple(NOp.const(dtypes.int, i) for i in range(4)), name="ex"))), NOp.var("var")), name="ls", allow_any_len=True),
-    image_contract_store),
+  (UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
+  (UPat({UOps.BARRIER, UOps.SINK}, src=UPat(UOps.STORE, src=(UPat(name="buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
+  (UPat(UOps.VECTORIZE, src=UPat(UOps.REDUCE), name="vec"), vectorize_reduce),
+  (UPat(UOps.VECTORIZE, src=UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}), name="vec"), vectorize_alu),
 ])
 
 # ***** transcendental *****
 
@@ -108,13 +107,11 @@ def reduce_before_expand(reduce, expand, x):
   expands = flatten([x.arg for x in reduce.src[1:] if x.op is UOps.EXPAND])
   if any(x in expands for x in expand.arg): return None
   red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
-  gep = tuple(UOp(UOps.GEP, reduce.dtype, (red,), i) for i in range(x.dtype.count))
-  return UOp(expand.op, expand.dtype, gep, expand.arg)
+  return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), i) for i in range(x.dtype.count)), expand.arg)
 
 def sum_collapse(phi_input, loop, val1, val2):
   for v1,v2 in [(val1, val2), (val2, val1)]:
-    if loop not in v1.parents:
-      return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+v1*(loop.src[1]-loop.src[0]).cast(v1.dtype)
+    if loop not in v1.parents: return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+v1*(loop.src[1]-loop.src[0]).cast(v1.dtype)
   return None
 
 def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None):
@@ -136,13 +133,6 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
 
 # this is symbolic 2.0
 constant_folder = PatternMatcher([
-  # CONTRACT before ALU/REDUCE/CAST
-  (UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.ALU, name="alu"),)),
-   lambda con, alu: UOp(alu.op, con.dtype, tuple(UOp(UOps.CONTRACT, x.dtype.vec(con.dtype.count), (x,), con.arg) for x in alu.src), alu.arg)),
-  (UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.REDUCE, dtype={dtypes.half, dtypes.bfloat16, dtypes.float}, name="red"),)),
-   lambda con, red: UOp(UOps.REDUCE, con.dtype, (UOp(UOps.CONTRACT, con.dtype, red.src[0:1], con.arg),)+red.src[1:], red.arg)),
-  (UPat(UOps.CONTRACT, name="con", src=(UPat(UOps.CAST, dtype={dtypes.half, dtypes.bfloat16, dtypes.float}, src=(UPat(name="casted"),)),)),
-   lambda con, casted: UOp(UOps.CAST, con.dtype, (UOp(UOps.CONTRACT, casted.dtype.vec(con.dtype.count), (casted,), con.arg),))),
   # bigint is rewritten to int32
   (UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND}, dtype=dtypes.bigint, name="x"),
    lambda x: UOp(x.op, dtypes.int32, x.src, x.arg)),
@@ -194,7 +184,6 @@ constant_folder = PatternMatcher([
   # const rules
   (NOp(UOps.GEP, src=(NOp.cvar("c"),), name="root"), lambda root, c: root.const(c.arg)),
   (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
-  (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
   # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
   (NOp(UOps.PHI, src=(NOp(UOps.DEFINE_ACC, name="acc"), NOp.var("acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
   (NOp(UOps.PHI, src=(NOp(UOps.DEFINE_ACC, src=(NOp.cvar(),)), NOp.var("x"))), lambda x: x),
@@ -307,8 +296,8 @@ constant_folder = PatternMatcher([
   # remove NOOPs from SINK
   (NOp(UOps.SINK, name="root"),
    lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
-  # ** move add consts to end **
-  #(UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1),
+  # ** move add consts to end (NOTE: this is still happening before constant folding) **
+  (UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None),
+  (UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y'))),
+   lambda x,c1,y: (x+y)+c1),
 ])
 
@@ -409,10 +398,19 @@ def no_vectorized_alu(alu):
     tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), i) for s in alu.src), alu.arg) for i in range(alu.dtype.count))
   return UOp(UOps.VECTORIZE, alu.dtype, alus)
 
+def create_gate(root:UOp) -> Optional[UOp]:
+  @functools.lru_cache(None)
+  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
+    if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER: return UOp(u.op, u.dtype, u.src[:-1]+(UOp(UOps.IF, None, (gate, u.src[-1])),), u.arg)
+    return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
+  return None if len(root.src) == 3 or (ret:=_gate_srcs(root, root.src[3])) is root else ret
+
 expander = PatternMatcher([
+  # create gate MUST BE BEFORE expander
+  (NOp(UOps.STORE, name="root"), create_gate),
+  # do expansion
   (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
-  (NOp(UOps.REDUCE, name="root"), do_reduce_with_expand),
   (NOp(UOps.CONTRACT, name="con"), do_contract),
   # remove EXPANDs from SINK
   (NOp(UOps.SINK, name="root"),
@@ -422,19 +420,8 @@ expander = PatternMatcher([
   (NOp(UOps.BARRIER, src=(NOp(UOps.EXPAND, name="ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)),
   # empty EXPAND is NOOP
   (NOp(UOps.EXPAND, src=(NOp.var('x'),), arg=()), lambda x: x),
-  # no ALU on vectorized dtypes
-  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
 ])
 
-def create_gate(root:UOp) -> Optional[UOp]:
-  @functools.lru_cache(None)
-  def _gate_srcs(u:UOp, gate:UOp) -> UOp:
-    if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER: return UOp(u.op, u.dtype, u.src[:-1]+(UOp(UOps.IF, None, (gate, u.src[-1])),), u.arg)
-    return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
-  return None if len(root.src) == 3 or (ret:=_gate_srcs(root, root.src[3])) is root else ret
-
-gate_creator = PatternMatcher([(NOp(UOps.STORE, name="root"), create_gate)])
-
 def delete_redundant_gates(root:UOp) -> Optional[UOp]:
   @functools.lru_cache(None)
   def find_gate(x:UOp) -> Optional[UOp]:
@@ -443,7 +430,15 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
   if len(root.src) == 3 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[3]: return None
   return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
 
-gate_folder = PatternMatcher([(NOp(UOps.STORE, name="root"), delete_redundant_gates)])
+reducer = PatternMatcher([
+  (NOp(UOps.REDUCE, name="root"), do_reduce_with_expand),
+  # no ALU on vectorized dtypes
+  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
+  # VECTORIZE a CONST is a CONST (eventually remove this rule)
+  (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
+  # delete_redundant_gates (after expand, is this still needed?)
+  (NOp(UOps.STORE, name="root"), delete_redundant_gates),
+])
 
 # *** uop graph ***
 
@@ -473,9 +468,7 @@ class UOpGraph:
     # used by linearizer
     self._uops: Optional[List[UOp]] = None
     self.opts = opts
-    # NOTE: gate folding must come after expand
-    gate_pms = gate_creator+gate_folder if opts is None or not opts.supports_float4 else gate_creator
-    self.folder = constant_folder+gate_pms if opts is None or not opts.supports_float4 else constant_folder+gate_pms+float4_folding
+    self.folder = constant_folder
     if TRANSCENDENTAL >= 2 or (opts is not None and TRANSCENDENTAL >= 1 and opts.device in {"CLANG", "LLVM"}):
       self.folder = self.folder + transcendental_folding
 
@@ -513,7 +506,9 @@ class UOpGraph:
 
     # expand
     UOpGraph.cnt += 1
-    if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, self.folder+expander+gate_folder)
+    if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
+      sink = graph_rewrite(sink, self.folder+expander+float4_folding if self.opts is not None and self.opts.supports_float4 else self.folder+expander)
+      sink = graph_rewrite(sink, self.folder+expander+reducer)
 
     # for PTX only
     if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)
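
The pattern behind fold_expanded, stripped of the UOp machinery: group every scalar access by everything that must match (the base index expression, the image x/y pair, the gate), then collapse runs of consecutive constant offsets into one vector access. A minimal sketch of that grouping in plain Python; ScalarLoad and fold_loads are illustrative names for this sketch, not tinygrad APIs, and simple offset alignment stands in for the patch's divides() check:

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class ScalarLoad:
  base: str                   # symbolic root of the index expression
  offset: int                 # constant offset added to that root
  gate: Optional[str] = None  # accesses under different gates never fold

def fold_loads(loads, fold_lengths=(4, 2)):
  # group constant offsets by (base, gate), like offsets_rootsrc in the patch
  offsets = defaultdict(dict)
  for i, ld in enumerate(loads): offsets[(ld.base, ld.gate)][ld.offset] = i
  folded, used = [], set()
  for (base, gate), offs in offsets.items():
    for o in sorted(offs):
      for n in fold_lengths:
        # fold a full, unused, aligned run of n consecutive offsets
        if o % n == 0 and all(o+i in offs and (base, gate, o+i) not in used for i in range(n)):
          folded.append((base, gate, o, n))  # one vec-n access replaces n scalars
          used.update((base, gate, o+i) for i in range(n))
          break
  return folded

print(fold_loads([ScalarLoad("gidx0*4", i) for i in range(4)]))
# -> [('gidx0*4', None, 0, 4)]: four scalar loads fold into one float4 load

On eight consecutive offsets this yields two float4 accesses, matching test_two_load_fold; with the mismatched gates of test_simple_load_dont_fold_different_gated the run splits into three accesses (one scalar under each gate plus a float2), matching that test's count of 3. The surrounding pass structure does the rest: create_gate runs inside expander (wrapping barrier'd loads in an IF before expansion), float4_folding folds the expanded accesses when the backend supports float4, and the reducer pass then handles REDUCEs and scalarizes any remaining vector ALU.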