From 86d34daac9b35283fab4282cf3c0ff2b3f2d4e90 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 6 Sep 2024 12:38:35 +0800 Subject: [PATCH] UOps.PHI -> UOps.ASSIGN [run_process_replay] (#6383) --- extra/backends/triton.py | 2 +- test/test_linearizer.py | 34 +++++++++++++++++----------------- tinygrad/codegen/uopgraph.py | 4 ++-- tinygrad/engine/graph.py | 2 +- tinygrad/ops.py | 4 ++-- tinygrad/renderer/assembly.py | 2 +- tinygrad/renderer/cstyle.py | 2 +- tinygrad/renderer/llvmir.py | 6 +++--- tinygrad/runtime/ops_python.py | 2 +- 9 files changed, 29 insertions(+), 29 deletions(-) diff --git a/extra/backends/triton.py b/extra/backends/triton.py index a9248f72e9..a10f172e8f 100644 --- a/extra/backends/triton.py +++ b/extra/backends/triton.py @@ -90,7 +90,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]): else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}") elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}") elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args) - elif uop == UOps.PHI: + elif uop == UOps.ASSIGN: kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}") r[u] = r[vin[0]] elif uop == UOps.STORE: diff --git a/test/test_linearizer.py b/test/test_linearizer.py index e198a21823..e6858661bf 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -865,7 +865,7 @@ class TestLinearizer(unittest.TestCase): lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] assert len(ranges) == 1 # NOTE: it collapses now - # RANGE -> LOAD -> RANGE -> PHI + # RANGE -> LOAD -> RANGE -> ASSIGN #assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]]) def test_three_nested_range(self): @@ -874,7 +874,7 @@ class TestLinearizer(unittest.TestCase): lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] assert len(ranges) == 1 # NOTE: it collapses now - # RANGE -> RANGE -> LOAD -> RANGE -> PHI + # RANGE -> RANGE -> LOAD -> RANGE -> ASSIGN # NOTE: nothing should toposort between the first two ranges #assert ranges[0]+1 == ranges[1] #assert any(x.op is UOps.LOAD for x in lin.uops[ranges[1]:ranges[2]]) @@ -884,7 +884,7 @@ class TestLinearizer(unittest.TestCase): out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum() lin = helper_linearizer_opt(out, wanna_output=[24])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] - # RANGE -> ALU -> RANGE -> ALU + LOAD -> PHI + # RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN assert any(x.op is UOps.ALU for x in lin.uops[ranges[0]:ranges[1]]) assert not any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]]) assert any(x.op in {UOps.ALU, UOps.LOAD} for x in lin.uops[ranges[1]:]) @@ -895,7 +895,7 @@ class TestLinearizer(unittest.TestCase): out = (a + b[0]).sum() + b[0] lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] - # LOAD -> RANGE -> LOAD -> PHI + # LOAD -> RANGE -> LOAD -> ASSIGN assert lin.uops[ranges[0]-2].op is UOps.LOAD def test_range_outer_op_before_phi_nested_range(self): @@ -906,11 +906,11 @@ class TestLinearizer(unittest.TestCase): ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] assert len(ranges) == 1 # NOTE: it collapses now #if getenv("PTX"): - # LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> PHI + # LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> ASSIGN # assert lin.uops[ranges[0]-2].op is UOps.LOAD # assert ranges[1] == ranges[0]+6 # assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU] - # LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI + # LOAD -> RANGE -> LOAD -> ALU -> RANGE -> ASSIGN #else: # assert lin.uops[ranges[0]-2].op is UOps.LOAD # assert ranges[1] == ranges[0]+3 @@ -920,7 +920,7 @@ class TestLinearizer(unittest.TestCase): a = Tensor.randn(4, 1).realize() out = a.sum() * a.sum() lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0] - # RANGE -> LOAD -> PHI -> ALU + # RANGE -> LOAD -> ASSIGN -> ALU end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE) assert lin.uops[end+1].op is UOps.ALU @@ -928,7 +928,7 @@ class TestLinearizer(unittest.TestCase): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0] - # RANGE -> LOAD -> PHI -> ALU + # RANGE -> LOAD -> ASSIGN -> ALU end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE) assert lin.uops[end+1].op is UOps.ALU @@ -1128,7 +1128,7 @@ class TestLinearizer(unittest.TestCase): k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] for u in k.uops: if u.op is UOps.WMMA: - assert u.src[-1].src[0].op != UOps.PHI + assert u.src[-1].src[0].op != UOps.ASSIGN @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") @unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation") @@ -1140,12 +1140,12 @@ class TestLinearizer(unittest.TestCase): for u in k.uops: if u.op is UOps.WMMA: #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) - assert u.src[-1].src[0].op != UOps.PHI + assert u.src[-1].src[0].op != UOps.ASSIGN @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") @unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation") def test_tensor_cores_unroll_casted_phi_with_children(self): - # all PHI children are outside the loop + # all ASSIGN children are outside the loop tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0] x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in) r = x.matmul(y, acc_dtype=tc.dtype_out).relu() @@ -1153,19 +1153,19 @@ class TestLinearizer(unittest.TestCase): for u in k.uops: if u.op is UOps.WMMA: #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) - assert u.src[-1].src[0].op != UOps.PHI + assert u.src[-1].src[0].op != UOps.ASSIGN @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_simple_unroll_no_between_phi_dependencies(self): x, y = Tensor.rand(128, 128), Tensor.rand(128, 128) r = (x@y).relu() k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1] - # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x PHI -> ENDRANGE + # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x ASSIGN -> ENDRANGE for u in k.uops: - if u.op is UOps.PHI: + if u.op is UOps.ASSIGN: assert u.src[1].op is UOps.ALU - # children of PHI are placed after ENDRANGE - if any(x.op is UOps.PHI for x in u.src): + # children of ASSIGN are placed after ENDRANGE + if any(x.op is UOps.ASSIGN for x in u.src): end_range = [i for i, x in enumerate(k.uops) if x.op is UOps.ENDRANGE][0] assert end_range < k.uops.index(u) @@ -1292,7 +1292,7 @@ class TestLinearizer(unittest.TestCase): if if_op:=next((u for u in uops if u.op is UOps.IF), None): uops = uops[:uops.index(if_op)] assert len(set([u.op for u in uops if u.op in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both" - assert len([u for u in uops if u.op is UOps.PHI]) == 0, "PHI should have been simplified" + assert len([u for u in uops if u.op is UOps.ASSIGN]) == 0, "ASSIGN should have been simplified" # TODO: once uops track min/max this will be fixed #assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops" diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 166055cf26..08678fea07 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -386,7 +386,7 @@ def do_reduce(root:UOp): acc = UOp(UOps.DEFINE_ACC, root.dtype, (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(reduce_parented), (acc_number,)) acc_number += 1 - ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret))) + ret = UOp(UOps.ASSIGN, root.dtype, (acc, acc.alu(root.arg, ret))) # for MAX, we can just ignore the unparented if root.arg is BinaryOps.ADD: for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype) @@ -512,7 +512,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: # scope children impact the toposort and END* insertion scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP} - range_phi = {r:[p for p in scope_children[r] if p.op is UOps.PHI] for r in scope_children if r.op is UOps.RANGE} + range_phi = {r:[p for p in scope_children[r] if p.op is UOps.ASSIGN] for r in scope_children if r.op is UOps.RANGE} queue:List[Tuple[int, UOp]] = [] def push(u:UOp): diff --git a/tinygrad/engine/graph.py b/tinygrad/engine/graph.py index 88c5ad6026..643d22fd99 100644 --- a/tinygrad/engine/graph.py +++ b/tinygrad/engine/graph.py @@ -73,7 +73,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False): uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0", UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484", - UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"} + UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"} graph_uops_cnt = 0 def graph_uops(uops:List[UOp]): global graph_uops_cnt diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 5fe880a3fb..a9d4c2f44c 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -272,7 +272,7 @@ class UOps(Enum): - Gate UOp, can only return `dtypes.bool`. We rewrite this to an IF block in the end. - **`arg`**: `None` """ - PHI = auto() + ASSIGN = auto() # control flow ops BARRIER = auto() """ @@ -321,7 +321,7 @@ class UOps(Enum): BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST} -END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)} +END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)} @dataclass(frozen=True, eq=False) class UOp(MathTrait): diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index b8349c4a71..f54fa19896 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -231,7 +231,7 @@ class PTXRenderer(Renderer): else: kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[3]] if has_gate else None, alt=r[src[2]] if has_gate else None, ss=mem_type, offset=src[1].arg)) - elif uop is UOps.PHI: + elif uop is UOps.ASSIGN: if dtype.count > 1: for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};") else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};") diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 0c812dc555..4e184dc788 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -156,7 +156,7 @@ class CStyleLanguage(Renderer): # NOTE: this relies on the load not happening if it's in the unselected branch if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype) kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};") - elif uop is UOps.PHI: + elif uop is UOps.ASSIGN: kk(f"{r[src[0]]} = {r[src[1]]};") r[u] = r[src[0]] elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 774afb8df1..ff219b2571 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -131,11 +131,11 @@ class LLVMRenderer(Renderer): else: val = bb[-1].load(bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)) lvars[u] = val - elif uop is UOps.PHI: + elif uop is UOps.ASSIGN: lvars[u] = lvars[src[1]] - # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC + # ASSIGN UOps can link to other ASSIGN Uops, backtrace this to DEFINE_ACC backward = src[0] - while backward.op is UOps.PHI: backward = backward.src[0] + while backward.op is UOps.ASSIGN: backward = backward.src[0] lvars[backward] = lvars[u] elif uop is UOps.ALU: lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype) diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 0a6f60b68c..222f393eb5 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -123,7 +123,7 @@ class PythonProgram: ul[i] = [load([inp[i][j] if dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)] else: ul[i] = load(inp) - elif uop is UOps.PHI: + elif uop is UOps.ASSIGN: for j in range(len(inp[0])): inp[0][j] = inp[1][j] ul[i] = inp[0] elif uop is UOps.GEP: