From 1242b302fa187d5ce235324bf0fbe0bc634532bd Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 17 Jul 2024 10:17:50 -0700
Subject: [PATCH] expand UOps with rewrite rules (#5501)

* expand UOps with rewrite rules

[run_process_replay]

* progress
* much closer
* close, way less bugs
* bunch of expander tests
* fix contract
* ops tests pass
* fix barrier
* mostly passing
* bitcast in expanded ops
* support more expand merges
* all tests pass maybe
* fix empty EXPAND
* fix LIN fuzzing
* add ALL_SAME assert
* all same
* all same work
* raise CompileError
* pass fuzz linearizer
* revert whitespace
* fix nv tensor core test
* fix mypy
* bug fix
* fuzzer passes
* put tests back
* expand arg to idx
---
 test/test_uop_graph.py       |  21 +--
 tinygrad/codegen/uopgraph.py | 266 ++++++++++++++++-------------------
 2 files changed, 127 insertions(+), 160 deletions(-)

diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index 3cfd4dcea6..ff3ad8ec16 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -270,15 +270,14 @@ class TestUOpGraph(TestUOps):
     self.assertEqual(endranges[-1].src[0], ranges[0])
 
 def expander_rewrite(sink):
-  #from tinygrad.codegen.uopgraph import expander, constant_folder
-  #together = PatternMatcher(expander.patterns + constant_folder.patterns)
-  #return graph_rewrite(sink, together)
-  out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
-  out.linearize()
-  return out.uops[-1]
+  from tinygrad.codegen.uopgraph import expander, constant_folder
+  together = PatternMatcher(expander.patterns + constant_folder.patterns)
+  return graph_rewrite(sink, together)
+  #out = UOpGraph(UOp(UOps.SINK, None, (sink,)))
+  #out.linearize()
+  #return out.uops[-1]
 
 class TestExpander(unittest.TestCase):
-  @unittest.skip
   def test_expand_add_broadcast(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
     sink = expander_rewrite(e1+3)
@@ -292,7 +291,6 @@ class TestExpander(unittest.TestCase):
     assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
     self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
 
-  @unittest.skip
   def test_contract_axis_1(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
@@ -302,7 +300,6 @@ class TestExpander(unittest.TestCase):
     self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
     self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
 
-  @unittest.skip
   def test_contract_axis_2(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,))
@@ -312,7 +309,6 @@ class TestExpander(unittest.TestCase):
     self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
     self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
 
-  @unittest.skip
   def test_contract_mid(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
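For reference, the src layouts these CONTRACT assertions encode follow from a row-major ordering of the EXPAND's (axis, size) pairs, with later axes varying fastest (the same convention _expand_arg_to_idx uses later in this patch). A standalone sketch in plain Python, with illustrative helper names that are not part of the patch, reproduces the groupings asserted above:

import itertools

def contract_groups(arg, contract_axis):
  # arg: tuple of (axis, size) pairs describing an EXPAND, laid out row-major.
  # Returns, per combination of the kept axes, the flat src indices that a
  # CONTRACT over contract_axis would gather into one vector.
  sizes = dict(arg)
  def flat(choice):
    idx = 0
    for axis, size in arg: idx = idx * size + choice[axis]
    return idx
  kept = [(a, s) for a, s in arg if a != contract_axis]
  groups = []
  for kc in itertools.product(*[range(s) for _, s in kept]):
    choice = dict(zip([a for a, _ in kept], kc))
    groups.append([flat({**choice, contract_axis: c}) for c in range(sizes[contract_axis])])
  return groups

assert contract_groups(((1,4),(2,4)), 1)[0] == [0,4,8,12]     # as in test_contract_axis_1
assert contract_groups(((1,4),(2,4)), 2)[3] == [12,13,14,15]  # as in test_contract_axis_2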
@@ -324,7 +320,6 @@ class TestExpander(unittest.TestCase):
     self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
     self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
 
-  @unittest.skip
   def test_expand_same_axis(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
     e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
@@ -332,7 +327,6 @@ class TestExpander(unittest.TestCase):
     assert sink.op is UOps.EXPAND and len(sink.src) == 4
     self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
 
-  @unittest.skip
   def test_expand_different_axis(self, flip=False):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
     e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
@@ -357,7 +351,6 @@ class TestExpander(unittest.TestCase):
     assert sink.op is UOps.CONST
     self.assertEqual(sink.arg, 3*4)
 
-  @unittest.skip
   def test_double_expand(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
     e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
@@ -367,7 +360,6 @@ class TestExpander(unittest.TestCase):
     assert sink.arg == ((1, 2), (2, 4))
     self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
 
-  @unittest.skip
   def test_double_expand_reverse(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
     e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
@@ -377,7 +369,6 @@ class TestExpander(unittest.TestCase):
     assert sink.arg == ((1, 4), (2, 2))
     self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
 
-  @unittest.skip
   def test_double_expand_middle(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
     e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
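The merge behavior pinned down by these tests can be restated in plain Python: EXPANDs over the same axis combine elementwise, while EXPANDs over different axes combine over the cartesian product of their axes, ordered with the lower-numbered axis outermost. A small illustrative sketch (helper names are not from the patch):

import itertools

def combine(vals1, axes1, vals2, axes2):
  # vals: flat value lists of two EXPANDs; axes: their (axis, size) args.
  # Returns the flat values of the combined (added) EXPAND, row-major over the
  # union of axes with the lower-numbered axis outermost.
  axes = sorted(set(a for a, _ in axes1) | set(a for a, _ in axes2))
  sizes = dict(axes1); sizes.update(dict(axes2))
  def flat(vals, arg, choice):
    idx = 0
    for axis, size in arg: idx = idx * size + choice[axis]
    return vals[idx]
  out = []
  for choice in itertools.product(*[range(sizes[a]) for a in axes]):
    c = dict(zip(axes, choice))
    out.append(flat(vals1, axes1, c) + flat(vals2, axes2, c))
  return out

# same axis: [x] + [4*x] merges elementwise to [0,5,10,15], as in test_expand_same_axis
assert combine([0,1,2,3], ((1,4),), [0,4,8,12], ((1,4),)) == [0,5,10,15]
# different axes: [4*x] on axis 1 with [x] on axis 2 covers all 16 sums 0..15
assert combine([0,4,8,12], ((1,4),), [0,1,2,3], ((2,4),)) == list(range(16))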
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index a73e782169..deb231e496 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -5,7 +5,7 @@
 from collections import defaultdict
 from tinygrad.dtype import dtypes, DType, PtrDType, ImageDType
 from tinygrad.shape.symbolic import Variable
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu
-from tinygrad.helpers import DEBUG, getenv, flatten, all_same, dedup, TRANSCENDENTAL, CI
+from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, prod, CI
 from tinygrad.codegen.uops import UOp, UOps, END_FOR_UOP, type_verify
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -68,122 +68,7 @@ class PatternMatcher:
       if (matches := _match(uop, p, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
     return None
 
-def expand_nodes(parents:Set[UOp], expands:List[UOp], base:UOp) -> List[UOp]:
-  # just in case, dedup expands
-  expands = dedup(expands)
-
-  # get children and define_accs
-  children = defaultdict(list)
-  define_accs = []
-  for p in parents:
-    if p.op is UOps.PHI:
-      wmma_reduce_axes = flatten([x.arg[7] for x in p.parents if x.op is UOps.WMMA])
-      parent_expands_for_acc = [x.arg[0][0] for x in p.parents if x in expands and x.arg[0][0] not in wmma_reduce_axes]
-      define_accs.append((p.src[0], parent_expands_for_acc))
-    for x in p.src:
-      children[x].append(p)
-
-  # get nodes on the path from root to the expand node
-  on_path: Dict[UOp, None] = {}
-  search = expands[:]
-  while len(search):
-    t = search.pop(0)
-    for cc in children[t]:
-      if cc in on_path: continue
-      on_path[cc] = None
-      search.append(cc)
-
-  # toposort the nodes on the path
-  # TODO: library!
-  in_degree: DefaultDict[UOp, int] = defaultdict(int)
-  for n in on_path:
-    for x in children[n]:
-      in_degree[x] += 1
-  toposort: List[UOp] = []
-  search2 = [p for p in on_path if in_degree[p] == 0]
-  seen: Set[UOp] = set()
-  while len(search2):
-    n = search2.pop(0)
-    if n in seen: continue
-    toposort.append(n)
-    for x in children[n]:
-      in_degree[x] -= 1
-      if in_degree[x] == 0:
-        search2.append(x)
-
-  # get replacements by index
-  replacements: Dict[int, List[int]] = {}
-  for r in expands:
-    if r.arg[0][0] in replacements: assert len(replacements[r.arg[0][0]]) == len(r.src)
-    else: replacements[r.arg[0][0]] = list(range(0, len(r.src)))
-
-  # get nodes on the path from root to the expand node
-  rps = list(itertools.product(*replacements.values()))
-
-  acc_number = 0
-  replaces: List[Dict[UOp, UOp]] = []
-  acc_cache: Dict[Tuple[Tuple[UOp, int, int], ...], UOp] = {}
-  for rp in rps:
-    rpk = dict(zip(replacements.keys(), rp))
-    replace = {r:r.src[rpk[r.arg[0][0]]] for r in expands}
-    for d, acc_parents in define_accs:
-      acc_index = tuple((d,x,rpk[x]) for x in acc_parents)
-      if acc_index in acc_cache:
-        replace[d] = acc_cache[acc_index]
-      else:
-        replace[d] = acc_cache[acc_index] = UOp(d.op, d.dtype, d.src, d.arg + (acc_number,))
-        acc_number += 1
-    replaces.append(replace)
-
-  for cc in toposort:
-    if cc.op is UOps.BARRIER:
-      super_replace = UOp(cc.op, cc.dtype, sum([tuple(replace.get(x, x) for x in cc.src) for replace in replaces], ()), cc.arg)
-      for replace in replaces:
-        replace[cc] = super_replace
-    else:
-      for replace in replaces:
-        tcc = replace.get(cc, cc) # NOTE: handle expands that are already replaced
-        replace[cc] = UOp(tcc.op, tcc.dtype, tuple(replace.get(x, x) for x in tcc.src), tcc.arg)
-  return [x.get(base, base) for x in replaces]
-
-# ***** reduce+image+contract handling *****
-
-def expand_wmma(wmma):
-  expands = [x for x in wmma.parents if x.op is UOps.EXPAND and (x.arg[0][0] in wmma.arg[-1] or x.arg[0][0] in wmma.arg[-2])]
-  if len(expands) == 0: return None
-  new_uops = expand_nodes(wmma.sparents, expands, wmma)
-  # TODO: assert that these are all the same. they have to be
-  return new_uops[0]
-
-acc_number = 0
-def replace_reduce(root):
-  global acc_number
-  expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
-
-  # add other expands for float4. TODO: should be a faster way
-  expand_args = [x.arg[0][0] for x in expands]
-  new_expands = [x for x in root.parents if x.op is UOps.EXPAND and x.arg[0][0] in expand_args]
-  expands = dedup(expands + new_expands)
-
-  if len(expands):
-    new_uops = expand_nodes(root.parents, expands, root.src[0])
-  else:
-    new_uops = [root.src[0]]
-
-  const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
-  acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x not in expands), (acc_number,))
-  acc_number += 1
-  ret = acc
-  for xx in new_uops: ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, xx)
-  return UOp(UOps.PHI, ret.dtype, (acc, ret))
-
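The expand_nodes machinery being removed in this hunk worked by graph duplication: for every combination of expansion-axis indices it built a substitution map (the rpk dict) and cloned the whole path from the EXPAND nodes up to the base node, so the subgraph was copied once per point in the product of the axis sizes. A much-simplified analogue of that idea, using illustrative tuple expressions rather than tinygrad's UOp:

import itertools

def expand_by_duplication(expr, axis_sizes):
  # expr: nested tuples like ("add", ("expand", 1), ("const", 3))
  # axis_sizes: {axis: size}; rebuild the expression once per index combination,
  # replacing each ("expand", axis) leaf with that combination's chosen index.
  results = []
  for combo in itertools.product(*[range(s) for s in axis_sizes.values()]):
    rpk = dict(zip(axis_sizes.keys(), combo))  # axis -> chosen index, like the old rpk
    def substitute(node):
      if node[0] == "expand": return ("const", rpk[node[1]])
      return (node[0],) + tuple(substitute(c) if isinstance(c, tuple) else c for c in node[1:])
    results.append(substitute(expr))
  return results

# one full copy of the expression per point of the 2x3 expansion
copies = expand_by_duplication(("add", ("expand", 1), ("const", 3)), {1: 2, 2: 3})
assert len(copies) == 6 and copies[0] == ("add", ("const", 0), ("const", 3))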
-def replace_contract(root:UOp):
-  parents, dtype = root.parents, cast(DType, root.dtype)
-  expands: List[UOp] = [x for x in parents if x.op is UOps.EXPAND and x.arg[0][0] in root.arg]
-  assert all_same(expand_lens := [dtype.count] + [len(x.src) for x in expands]), expand_lens
-  ret = expand_nodes(parents, expands, root.src[0])
-  if len(ret) == 1: ret = ret*dtype.count # TODO: why is this needed?
-  return UOp(UOps.VECTORIZE, dtype, tuple(ret))
-
+# ***** image handling *****
 
 def fix_image_idx(ls:UOp):
   if ls.src[1].dtype is None or ls.src[1].dtype.count != 1: return None
@@ -203,25 +88,6 @@ def fix_image_idx(ls:UOp):
     return ret
   return UOp(ls.op, ls.dtype, (ls.src[0], image_idx) + ls.src[2:], ls.arg)
 
-def cast_reduce(cst):
-  if cst.dtype.scalar() == cst.dtype: return None # not for normal CAST. TODO: the merging one shouldn't be CAST
-  if not all_same([(x.arg, x.src[1:]) for x in cst.src]): return None
-  fst_red = cst.src[0]
-  red = UOp(UOps.VECTORIZE, cst.dtype, tuple(x.src[0] for x in cst.src))
-  return UOp(UOps.REDUCE, red.dtype, (red,) + fst_red.src[1:], fst_red.arg)
-
-contractor = PatternMatcher([
-  # contracts
-  (UOp(UOps.CONTRACT).name("root"), replace_contract),
-])
-
-reducer = PatternMatcher([
-  (UOp(UOps.REDUCE).name("root"), replace_reduce),
-  (UOp(UOps.WMMA).name("wmma"), expand_wmma),
-  # image indexing. TODO: why can't this just go after the float stuff?
-  (UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
-])
-
 # ***** float4 handling *****
 
 def float4_expand_load(load, buf, ex, idx=UOp.const(dtypes.int, 0), idx2=None):
@@ -483,6 +349,123 @@ constant_folder = PatternMatcher([
 constant_folder_w_f4 = PatternMatcher(constant_folder.patterns + float4_folding.patterns)
 
+# *** uop expander ***
+
+def _expand_arg_to_idx(arg:Tuple[Tuple[int, int], ...], rpk:Dict[int, int]):
+  idx, mul = 0, 1
+  for axis,m in arg[::-1]:
+    idx += rpk[axis] * mul
+    mul *= m
+  return idx
+
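_expand_arg_to_idx is the core bookkeeping of the new expander: an EXPAND's src tuple is laid out row-major over its (axis, size) arg, and the helper recovers the flat index for a given choice of per-axis indices (rpk). A standalone copy with a worked example, runnable without tinygrad:

def expand_arg_to_idx(arg, rpk):
  # arg: tuple of (axis, size) pairs; rpk: axis -> chosen index.
  # The last axis in arg varies fastest in the flat src layout.
  idx, mul = 0, 1
  for axis, m in arg[::-1]:
    idx += rpk[axis] * mul
    mul *= m
  return idx

# for arg ((1,4),(2,4)): axis 2 varies fastest, so axis1=2, axis2=3 lands at 2*4 + 3 = 11
assert expand_arg_to_idx(((1,4),(2,4)), {1: 2, 2: 3}) == 11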
+def do_expand(root:UOp):
+  if root.op is UOps.REDUCE:
+    if root.src[0].op is not UOps.EXPAND: return None
+    reduce_expand_args = flatten([x.arg for x in root.src[1:] if x.op is UOps.EXPAND])
+    expand_args = tuple(x for x in root.src[0].arg if x not in reduce_expand_args)
+    if len(expand_args) == 0: return None
+    dont_expand_args = tuple(x for x in root.src[0].arg if x in reduce_expand_args)
+  else:
+    expands = [x for x in root.src if x.op is UOps.EXPAND]
+    if len(expands) == 0: return None
+    expand_args = tuple(sorted(dedup(flatten([x.arg for x in expands]))))
+    if root.op is UOps.WMMA:
+      dont_expand_args = tuple(x for x in expand_args if x[0] in root.arg[-1] or x[0] in root.arg[-2])
+      expand_args = tuple(x for x in expand_args if x not in dont_expand_args)
+    else:
+      dont_expand_args = ()
+  new_srcs = []
+  for choices in itertools.product(*[range(x[1]) for x in expand_args]):
+    rpk = dict(zip([x[0] for x in expand_args], choices))
+    new_src = []
+    for src in root.src:
+      if src.op is UOps.EXPAND:
+        lnew_src = []
+        for lchoices in itertools.product(*[range(x[1]) for x in dont_expand_args]):
+          lrpk = {**rpk, **dict(zip([x[0] for x in dont_expand_args], lchoices))}
+          lnew_src.append(src.src[_expand_arg_to_idx(src.arg, lrpk)])
+        if len(dont_expand_args):
+          if root.op is UOps.WMMA:
+            new_src.append(lnew_src[0])  # TODO: is this always right?
+          else:
+            new_src.append(UOp(UOps.EXPAND, root.dtype, tuple(lnew_src), dont_expand_args))
+        else:
+          assert len(lnew_src) == 1
+          new_src.append(lnew_src[0])
+      else:
+        new_src.append(src)
+    new_srcs.append(UOp(root.op, root.dtype, tuple(new_src), root.arg))
+  if root.op is UOps.EXPAND:
+    expand_args, old_expand_args = tuple(sorted(root.arg+expand_args)), expand_args
+    assert len(expand_args) == (len(old_expand_args) + len(root.arg))
+    new_new_srcs = []
+    for choices in itertools.product(*[range(x[1]) for x in expand_args]):
+      rpk = dict(zip([x[0] for x in expand_args], choices))
+      new_new_srcs.append(new_srcs[_expand_arg_to_idx(old_expand_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)])
+    new_srcs = new_new_srcs
+  assert prod([x[1] for x in expand_args]) == len(new_srcs)
+  return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
+
+acc_number = 0
+def do_reduce_with_expand(root):
+  global acc_number
+  expands = [x for x in root.src[1:] if x.op is UOps.EXPAND]
+  expands_reduce = [x for x in expands if root.src[0].op is UOps.EXPAND and all(y in root.src[0].arg for y in x.arg)]
+  expands_non_reduce = [x for x in expands if x not in expands_reduce]
+  const = UOp.const(root.dtype.scalar(), dtypes.as_const(0, root.dtype.scalar()) if root.arg is ReduceOps.SUM else dtypes.min(root.dtype.scalar()))
+  ret = acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(x for x in root.src[1:] if x.op is not UOps.EXPAND), (acc_number,))
+  acc_number += 1
+  if len(expands_reduce):
+    assert root.src[0].op is UOps.EXPAND
+    expand_reduce_args = dedup(flatten([x.arg for x in expands_reduce]))
+    assert prod([y[1] for y in expand_reduce_args]) == len(root.src[0].src)
+    for xx in root.src[0].src:
+      ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, xx)
+  else:
+    ret = UOp.alu({ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[cast(ReduceOps, root.arg)], ret, root.src[0])
+  ret = UOp(UOps.PHI, ret.dtype, (acc, ret))
+  if len(expands_non_reduce): ret = ret * prod([sz for _,sz in flatten([x.arg for x in expands_non_reduce])])
+  return ret
+
+def do_contract(con:UOp):
+  ex = con.src[0]
+  assert con.dtype is not None
+  # CONTRACT without EXPAND repeats the element VECTORIZED
+  if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
+  # simple CONTRACT and EXPAND cancel out
+  if len(ex.arg) == 1 and len(con.arg) == 1 and ex.arg[0][0] in con.arg: return UOp(UOps.VECTORIZE, con.dtype, ex.src)
+  # complex CONTRACT may only remove one axis from EXPAND
+  assert len(con.arg) == 1, "contract arg one is all that's supported"
+  try:
+    split_index = [x[0] for x in ex.arg].index(con.arg[0])
+  except ValueError:
+    # CONTRACT without EXPAND (still) repeats the element VECTORIZED
+    return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
+  assert con.dtype.count == ex.arg[split_index][1], "contract arg must match"
+  number_after = prod([x[1] for x in ex.arg[split_index+1:]])
+  to_join = [ex.src[i:i+number_after] for i in range(0, len(ex.src), number_after)]
+  srcs = []
+  for i in range(0, len(to_join), con.dtype.count):
+    srcs += [UOp(UOps.VECTORIZE, con.dtype, tuple(src)) for src in zip(*to_join[i:i+con.dtype.count])]
+  return UOp(UOps.EXPAND, con.dtype, tuple(srcs), tuple(x for x in ex.arg if x[0] != con.arg[0]))
+
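do_reduce_with_expand turns a REDUCE whose input is an EXPAND into an explicit accumulator: DEFINE_ACC starts at the identity (0 for SUM, the dtype minimum for MAX), one ALU op is emitted per expanded source, and a PHI closes the accumulated value. The unrolled arithmetic has the following plain-Python shape (a sketch, not tinygrad code):

def unrolled_reduce(values, op):
  # identity element, then one accumulate step per element of the EXPAND's src
  acc = 0 if op == "sum" else float("-inf")
  for v in values:
    acc = acc + v if op == "sum" else max(acc, v)
  return acc

assert unrolled_reduce([0, 1, 2, 3], "sum") == 6
assert unrolled_reduce([0, 1, 2, 3], "max") == 3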
+expander = PatternMatcher([
+  (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
+         UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF}, name="root"), do_expand),
+  (UOp(UOps.REDUCE).name("root"), do_reduce_with_expand),
+  (UOp(UOps.CONTRACT).name("con"), do_contract),
+  # remove EXPANDs from SINK
+  (UOp(UOps.SINK).name("root"),
+   lambda root: UOp(UOps.SINK, root.dtype, a, root.arg)
+     if len(a:=tuple(flatten(x.src if x.op is UOps.EXPAND else (x,) for x in root.src))) != len(root.src) else None),
+  # BARRIERs aren't actually expanded
+  (UOp(UOps.BARRIER, src=(UOp(UOps.EXPAND).name("ex"),)), lambda ex: UOp(UOps.EXPAND, None, (UOp(UOps.BARRIER, None, ex.src),)*len(ex.src), ex.arg)),
+  # image indexing (needs to be here)
+  (UPat({UOps.LOAD, UOps.STORE}, name="ls"), fix_image_idx),
+  # empty EXPAND is NOOP
+  (UOp(UOps.EXPAND, src=(UOp.var('x'),), arg=()), lambda x: x),
+])
+
 # *** uop graph ***
 
 def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], in_degree:Dict[UOp, int]):
@@ -565,16 +548,9 @@ class UOpGraph:
     sink = graph_rewrite(sink, self.folder)
     if extra_pm: sink = graph_rewrite(sink, PatternMatcher(self.folder.patterns+extra_pm.patterns))
 
+    # expand
     UOpGraph.cnt += 1
-    if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
-      # do contracts/reduces
-      sink = graph_rewrite(sink, contractor)
-      sink = graph_rewrite(sink, reducer)
-
-      # do upcasts (after reduce unrolls and rewrites)
-      expands = list(sorted(x for x in sink.sparents if x.op is UOps.EXPAND))
-      new_nodes = expand_nodes(sink.sparents, expands, sink)
-      sink = UOp(UOps.SINK, None, tuple(flatten([x.src for x in new_nodes]))) # merge the sinks
+    if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): sink = graph_rewrite(sink, expander)
 
     # do graph rewrite (2)
     sink = graph_rewrite(sink, self.folder)
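With contractor, reducer, and expand_nodes gone, expansion is now just another graph_rewrite pass over the expander PatternMatcher. graph_rewrite itself is not shown in this patch; as a simplified analogue only, a rule-rewrite driver keeps applying the first matching rule until a fixpoint is reached:

def rewrite_fixpoint(node, rules, max_iters=1000):
  # only rewrites the root node here; the real matcher applies rules across the whole graph
  for _ in range(max_iters):
    for rule in rules:
      new_node = rule(node)
      if new_node is not None:
        node = new_node
        break
    else:
      return node  # no rule fired: fixpoint reached
  raise RuntimeError("rewrite did not converge")

# toy rule set on (op, *args) tuples: fold an add of two constants
rules = [lambda n: ("const", n[1][1] + n[2][1]) if n[0] == "add" and n[1][0] == n[2][0] == "const" else None]
assert rewrite_fixpoint(("add", ("const", 1), ("const", 2)), rules) == ("const", 3)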