From fa14f7b4fdc19dbd394001026bb03c53e10dc582 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 23 Jul 2024 18:08:33 -0700 Subject: [PATCH] switch contract arg to match expand arg [run_process_replay] (#5667) * switch contract arg to match expand arg [run_process_replay] * support multiaxis contract too, it's easy * cancel contract/expand --- test/test_uop_graph.py | 15 +++++++-------- tinygrad/codegen/lowerer.py | 4 ++-- tinygrad/codegen/uopgraph.py | 28 +++++++++++----------------- 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 7e831709ed..2f598de887 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -285,14 +285,14 @@ class TestExpander(unittest.TestCase): def test_contract_simple(self): e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) - con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,)) + con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),)) sink = expander_rewrite(con) assert sink.op is UOps.VECTORIZE and len(sink.src) == 4 self.assertListEqual([x.arg for x in sink.src], [0,1,2,3]) def test_contract_axis_1(self): e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4))) - con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,)) + con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),)) sink = expander_rewrite(con) assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((2,4),) assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4 @@ -301,7 +301,7 @@ class TestExpander(unittest.TestCase): def test_contract_axis_2(self): e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4))) - con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,)) + con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),)) sink = expander_rewrite(con) assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,4),) assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4 @@ -310,25 +310,24 @@ class TestExpander(unittest.TestCase): def test_contract_axis_2_big(self): e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2))) - con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,)) + con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),)) sink = expander_rewrite(con) assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2)) self.assertListEqual([x.arg for x in sink.src[0].src], [0,4]) self.assertListEqual([x.arg for x in sink.src[6].src], [10,14]) - @unittest.skip("TODO: add support for this") def test_contract_multi_axis(self): e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2))) - sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (3,2))) + sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3,2),(2,2)))) assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2)) self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,2,6]) - sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,3))) + sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,2),(3,2)))) assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2)) self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,4,6]) def test_contract_mid(self): e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2))) - con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,)) + con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),)) sink = expander_rewrite(con) assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,2),(3,2)) assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 2 diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 4a1f564895..8717816b7c 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -173,8 +173,8 @@ class IndependentLowerer: if x.op is ReduceOps.WMMA: wmma_sz, upcast_axis = x.arg[4], x.arg[6] ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=( - UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)), - UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)), + UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=((upcast_axis[0], wmma_sz[0]),)), + UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=((upcast_axis[1], wmma_sz[1]),)), UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg) return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),)) # NOTE: always using ridxs is fine here diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index d24b5a647c..51d24598bb 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -23,7 +23,7 @@ def image_contract_load(buf, idx, idy, id4, ls_allow_any_len): ls_allow_any_len.const(float('nan'))) def image_contract_store(buf, ex, idx, idy, ls_allow_any_len, var): - new_var = UOp(UOps.CONTRACT, var.dtype.vec(4), (var,), (ex.arg[0][0],)) + new_var = UOp(UOps.CONTRACT, var.dtype.vec(4), (var,), ((ex.arg[0][0],4),)) return UOp(UOps.STORE, None, (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy)), new_var) + ls_allow_any_len.src[3:]) # ***** float4 handling ***** @@ -47,7 +47,7 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype if idx3 is not None: idx = idx + idx3 if not idx.divides(len(ex.src)): return None - new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg[0][0],)) + new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), ((ex.arg[0][0],len(ex.src)),)) return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:]) float4_folding = PatternMatcher([ @@ -379,23 +379,17 @@ def do_contract(con:UOp): ex = con.src[0] assert con.dtype is not None # CONTRACT without EXPAND repeats the element VECTORIZED - if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count) - # simple CONTRACT and EXPAND cancel out - if len(ex.arg) == 1 and len(con.arg) == 1 and ex.arg[0][0] in con.arg: return UOp(UOps.VECTORIZE, con.dtype, ex.src) - # complex CONTRACT may only remove one axis from EXPAND - assert len(con.arg) == 1, "contract arg one is all that's supported" - try: - split_index = [x[0] for x in ex.arg].index(con.arg[0]) - except ValueError: - # CONTRACT without EXPAND (still) repeats the element VECTORIZED + if ex.op is not UOps.EXPAND or not all(x in ex.arg for x in con.arg): + assert ex.op is not UOps.EXPAND or not any(x in ex.arg for x in con.arg), "partial contract not supported" return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count) - assert con.dtype.count == ex.arg[split_index][1], "contract arg must match" - number_after = prod([x[1] for x in ex.arg[split_index+1:]]) - to_join = [ex.src[i:i+number_after] for i in range(0, len(ex.src), number_after)] + # simple CONTRACT and EXPAND cancel out + if len(ex.arg) == 1 and len(con.arg) == 1 and ex.arg == con.arg: return UOp(UOps.VECTORIZE, con.dtype, ex.src) + # complex CONTRACT may remove several axes from EXPAND srcs = [] - for i in range(0, len(to_join), con.dtype.count): - srcs += [UOp(UOps.VECTORIZE, con.dtype, tuple(src)) for src in zip(*to_join[i:i+con.dtype.count])] - return UOp(UOps.EXPAND, con.dtype, tuple(srcs), tuple(x for x in ex.arg if x[0] != con.arg[0])) + for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)): + lsrcs = [ex.src[_expand_arg_to_idx(ex.arg, {**rpk, **lrpk})] for lrpk in _choices_from_args(con.arg)] + srcs.append(UOp(UOps.VECTORIZE, con.dtype, tuple(lsrcs))) + return UOp(UOps.EXPAND, con.dtype, tuple(srcs), new_ex_args) def no_vectorized_alu(alu): if alu.dtype.count == 1: return None