diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 267a37a9c3..c6e967dd7f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,4 +48,4 @@ repos: entry: env PYTHONPATH="." python3 -m pylint tinygrad/ language: system always_run: true - pass_filenames: false + pass_filenames: false \ No newline at end of file diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 7e78e4882c..26b12224d5 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -16,7 +16,7 @@ from tinygrad.ops import LazyOp, UnaryOps, BufferOps from test.helpers import is_dtype_supported def tuplize_uops(uops:List[UOp]) -> Tuple: - return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops]) + return tuple([(x.op, x.dtype, tuple(uops.index(x) for x in x.src), x.arg) for x in uops]) device = Device[Device.DEFAULT] diff --git a/test/external/fuzz_uops.py b/test/external/fuzz_uops.py index 64e32271d7..2f7382d02f 100644 --- a/test/external/fuzz_uops.py +++ b/test/external/fuzz_uops.py @@ -12,11 +12,11 @@ def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int] paths: List[List[UOp]] = [] # TODO: express DEFINE_ACC and loop children conditions in the graph, builtin. for p in find_all_toposorts(graph, in_degree): - assert p[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {p[-1]}" + assert p[-1].op is UOps.SINK, f"didn't end with SINK, ended with {p[-1]}" paths.append(path:=list(p[:-1])) for u in path: - if u.uop is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,))) - if u.uop is UOps.RANGE: + if u.op is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,))) + if u.op is UOps.RANGE: path.insert(max(path.index(x) for x in loops_children[u] if x in path)+1, UOp(UOps.ENDRANGE, None, (u,))) return paths @@ -58,9 +58,9 @@ def find_all_toposorts(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[ def recurse_paths(path:List[UOp]): for v, d in in_degree.items(): if d != 0 or v in visited: continue - if v.uop is UOps.DEFINE_ACC and any(l not in path for l in v.vin): continue + if v.op is UOps.DEFINE_ACC and any(l not in path for l in v.src): continue for u in graph[v]: in_degree[u] -= 1 - if v.uop is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.vin), v) + if v.op is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.src), v) else: path.append(v) visited.add(v) recurse_paths(path) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 4260967988..b4f0d68a92 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -38,7 +38,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi k = Linearizer(realized_ast) k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt) k.linearize() - assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered" + assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered" assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" np_c = np_a @ np_b if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3 @@ -54,7 +54,7 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d k = Linearizer(realized_ast) k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt) k.linearize() - wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA]) + wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA]) tcs = len([x for x in k.applied_opts if x.op is OptOps.TC]) if 
ensure_triggered: assert wmmas > 0, "tensor core not triggered" @@ -94,8 +94,8 @@ class TestLinearizer(unittest.TestCase): b_t = Tensor.full(st.shape, 3).contiguous().realize() lin = helper_linearizer_ast((out0, out1), [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0] - stores = [u for u in lin.uops if u.uop is UOps.STORE] - mutable_bufs = [u for u in lin.uops if u.uop is UOps.DEFINE_GLOBAL and u.arg[-1]] + stores = [u for u in lin.uops if u.op is UOps.STORE] + mutable_bufs = [u for u in lin.uops if u.op is UOps.DEFINE_GLOBAL and u.arg[-1]] assert len(mutable_bufs) == len(stores) == 2 assert [u.arg[0] for u in mutable_bufs] == [0, 1] @@ -108,92 +108,92 @@ class TestLinearizer(unittest.TestCase): load_t = Tensor.full(load.st.shape, 1).contiguous().realize() k = helper_linearizer_ast(ast, [load_t], wanna_output=[load_t.numpy().sum()])[1] - self.assertEqual(k.uops[-1].uop, UOps.ENDIF) - self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.uop is UOps.STORE][-1]), k.uops.uops.index(k.uops[-1])) + self.assertEqual(k.uops[-1].op, UOps.ENDIF) + self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.op is UOps.STORE][-1]), k.uops.uops.index(k.uops[-1])) def test_two_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] if getenv("PTX"): # RANGE -> 2xLOAD_INDEXING -> LOAD -> RANGE -> PHI assert ranges[1] == ranges[0]+4 - assert lin.uops[ranges[0]+3].uop is UOps.LOAD + assert lin.uops[ranges[0]+3].op is UOps.LOAD else: # RANGE -> LOAD -> RANGE -> PHI assert ranges[1] == ranges[0]+2 - assert lin.uops[ranges[0]+1].uop is UOps.LOAD + assert lin.uops[ranges[0]+1].op is UOps.LOAD def test_three_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] if getenv("PTX"): # RANGE -> RANGE -> 2xLOAD_INDEXING -> LOAD -> RANGE -> PHI assert ranges[2] == ranges[1]+4 == ranges[0]+5 - assert lin.uops[ranges[1]+3].uop is UOps.LOAD + assert lin.uops[ranges[1]+3].op is UOps.LOAD else: # RANGE -> RANGE -> LOAD -> RANGE -> PHI assert ranges[2] == ranges[1]+2 == ranges[0]+3 - assert lin.uops[ranges[1]+1].uop is UOps.LOAD + assert lin.uops[ranges[1]+1].op is UOps.LOAD def test_two_nested_range_alt_indexing(self): a = Tensor([2, 2]).realize() out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum() lin = helper_linearizer_opt(out, wanna_output=[24])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] if getenv("PTX"): # RANGE -> CAST ridx -> LOAD_INDEXING -> 4x ALU -> RANGE -> LOAD -> RANGE -> PHI assert ranges[1] == ranges[0]+6 - assert lin.uops[ranges[1]+11].uop is UOps.ENDRANGE + assert lin.uops[ranges[1]+11].op is UOps.ENDRANGE else: # RANGE -> 4x ALU -> RANGE -> 9x ALU + 1x LOAD -> PHI assert ranges[1] == ranges[0]+5 - assert lin.uops[ranges[1]+11].uop is UOps.ENDRANGE + assert lin.uops[ranges[1]+11].op is UOps.ENDRANGE def test_range_outer_op_before_phi(self): a = Tensor.randn(4, 
1).realize() b = Tensor.randn(1, 1).realize() out = (a + b[0]).sum() + b[0] lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()[0]])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] # LOAD -> RANGE -> LOAD -> PHI - assert lin.uops[ranges[0]-2].uop is UOps.LOAD + assert lin.uops[ranges[0]-2].op is UOps.LOAD def test_range_outer_op_before_phi_nested_range(self): a = Tensor.randn(2, ).realize() b = Tensor.randn(1, 1).realize() out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0] lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()[0]])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] if getenv("PTX"): # LOAD -> RANGE -> 3xLOAD_INDEXING -> LOAD -> ALU -> RANGE -> PHI - assert lin.uops[ranges[0]-2].uop is UOps.LOAD + assert lin.uops[ranges[0]-2].op is UOps.LOAD assert ranges[1] == ranges[0]+5 - assert [x.uop for x in lin.uops[ranges[0]+3:ranges[0]+5]] == [UOps.LOAD, UOps.ALU] + assert [x.op for x in lin.uops[ranges[0]+3:ranges[0]+5]] == [UOps.LOAD, UOps.ALU] # LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI else: - assert lin.uops[ranges[0]-2].uop is UOps.LOAD + assert lin.uops[ranges[0]-2].op is UOps.LOAD assert ranges[1] == ranges[0]+3 - assert [x.uop for x in lin.uops[ranges[0]+1:ranges[0]+3]] == [UOps.LOAD, UOps.ALU] + assert [x.op for x in lin.uops[ranges[0]+1:ranges[0]+3]] == [UOps.LOAD, UOps.ALU] def test_range_outer_op_after_phi(self): a = Tensor.randn(4, 1).realize() out = a.sum() * a.sum() lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0] # RANGE -> LOAD -> PHI -> ALU - end = max(i for i,u in enumerate(lin.uops) if u.uop is UOps.ENDRANGE) - assert lin.uops[end+1].uop is UOps.ALU + end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE) + assert lin.uops[end+1].op is UOps.ALU def test_range_outer_op_after_phi_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0] # RANGE -> LOAD -> PHI -> ALU - end = max(i for i,u in enumerate(lin.uops) if u.uop is UOps.ENDRANGE) - assert lin.uops[end+1].uop is UOps.ALU + end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE) + assert lin.uops[end+1].op is UOps.ALU @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @@ -202,11 +202,11 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*ast) k.hand_coded_optimizations() k.linearize() - self.assertEqual(len(endifs:=[x for x in k.uops if x.uop is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.uop is UOps.IF])) - self.assertEqual(len(barriers:=[x for x in k.uops if x.uop is UOps.BARRIER]), 3) - self.assertEqual(k.uops[k.uops.uops.index(endifs[0])-1].uop, UOps.STORE) + self.assertEqual(len(endifs:=[x for x in k.uops if x.op is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.op is UOps.IF])) + self.assertEqual(len(barriers:=[x for x in k.uops if x.op is UOps.BARRIER]), 3) + self.assertEqual(k.uops[k.uops.uops.index(endifs[0])-1].op, UOps.STORE) self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+1], barriers[1]) - 
self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+2].uop, UOps.LOAD) + self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+2].op, UOps.LOAD) self.assertLess(k.uops.uops.index(barriers[0]), k.uops.uops.index(ifs[0])) self.assertLess(k.uops.uops.index(ifs[0]), k.uops.uops.index(endifs[0])) self.assertLess(k.uops.uops.index(barriers[1]), k.uops.uops.index(ifs[1])) @@ -245,13 +245,13 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*ast) k.hand_coded_optimizations() k.linearize() - local_buf = [u for u in k.uops if u.uop is UOps.DEFINE_LOCAL] - self.assertEqual(len(real_local_stores:=[u for u in k.uops if u.uop is UOps.STORE and any([lb in u.vin for lb in local_buf])]), 3, \ + local_buf = [u for u in k.uops if u.op is UOps.DEFINE_LOCAL] + self.assertEqual(len(real_local_stores:=[u for u in k.uops if u.op is UOps.STORE and any([lb in u.src for lb in local_buf])]), 3, \ f"should have generated 3 BufferOps.STORE to the local buf but got {len(real_local_stores)}") - self.assertEqual(len(real_local_loads:=[u for u in k.uops if u.uop is UOps.LOAD and any([lb in u.vin for lb in local_buf])]), 3, \ + self.assertEqual(len(real_local_loads:=[u for u in k.uops if u.op is UOps.LOAD and any([lb in u.src for lb in local_buf])]), 3, \ f"should have generated 3 BufferOps.LOAD to the local buf but got {len(real_local_loads)}") - self.assertEqual((real_local_stores[1].vin[1].uop, real_local_stores[1].vin[1].arg), (UOps.CONST, 0)) - self.assertEqual((real_local_loads[1].vin[1].uop, real_local_loads[1].vin[1].arg), (UOps.CONST, 0)) + self.assertEqual((real_local_stores[1].src[1].op, real_local_stores[1].src[1].arg), (UOps.CONST, 0)) + self.assertEqual((real_local_loads[1].src[1].op, real_local_loads[1].src[1].arg), (UOps.CONST, 0)) x = Tensor.randn(3,27,32).realize() helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(axis=2, ddof=0).reshape(-1)]) @@ -261,9 +261,9 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*ast) k.upcast() k.linearize() - define_globals = [u for u in k.uops if u.uop is UOps.DEFINE_GLOBAL] - self.assertEqual(len([u for u in k.uops if u.uop is UOps.LOAD and define_globals[1] in u.vin]), 7) - self.assertEqual(len([u for u in k.uops if u.uop is UOps.ALU and u.arg is BinaryOps.ADD]), 25) + define_globals = [u for u in k.uops if u.op is UOps.DEFINE_GLOBAL] + self.assertEqual(len([u for u in k.uops if u.op is UOps.LOAD and define_globals[1] in u.src]), 7) + self.assertEqual(len([u for u in k.uops if u.op is UOps.ALU and u.arg is BinaryOps.ADD]), 25) opts = [[Opt(op=OptOps.UPCAST, axis=0, amt=2)], [Opt(op=OptOps.UPCAST, axis=0, amt=4)]] x = Tensor.randn(8,7).softmax().realize() helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)]) @@ -288,12 +288,12 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*ast) k.hand_coded_optimizations() k.linearize() - def get_recursive_children(x:UOp): return set.union(set(x.vin), *[get_recursive_children(v) for v in x.vin]) + def get_recursive_children(x:UOp): return set.union(set(x.src), *[get_recursive_children(v) for v in x.src]) loop = None for u in k.uops: - if u.uop is UOps.RANGE: loop = u + if u.op is UOps.RANGE: loop = u elif loop is None: continue - elif u.uop is UOps.ENDRANGE and loop in u.vin: loop = None + elif u.op is UOps.ENDRANGE and loop in u.src: loop = None else: self.assertIn(loop, get_recursive_children(u), f"Any uop within a loop should depend on the loop: {u}") x = Tensor.randn(3, 27, 32).realize() helper_linearizer_ast(ast, [x], 
wanna_output= \ @@ -358,7 +358,7 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() - num_loads = len([uop for uop in k.uops if uop.uop is UOps.LOAD]) + num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD]) assert num_loads <= 4, "more load uops than needed" assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?" @@ -377,7 +377,7 @@ class TestLinearizer(unittest.TestCase): lin.linearize() assert len(lin.uops.uops) <= 7, "too many uops" - a_bufs = [u.uop for u in lin.uops.uops[-1].vin[2].vin] + a_bufs = [u.op for u in lin.uops.uops[-1].src[2].src] assert a_bufs == [UOps.LOAD, UOps.CONST] def test_upcast_cse(self): @@ -389,7 +389,7 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() - num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU]) + num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU]) assert num_ops <= 1, "more alu uops than needed" @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") @@ -401,11 +401,11 @@ class TestLinearizer(unittest.TestCase): k.upcast() k.upcast() k.linearize() - accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC] - stores = [u for u in k.uops if u.uop is UOps.STORE] + accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC] + stores = [u for u in k.uops if u.op is UOps.STORE] assert len(accs) == 0 # it's removed now assert len(stores) == 1 - assert stores[0].vin[-1].dtype == dtypes.float.vec(4) + assert stores[0].src[-1].dtype == dtypes.float.vec(4) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @@ -417,15 +417,15 @@ class TestLinearizer(unittest.TestCase): k.hand_coded_optimizations() k.linearize() - accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC] - stores = [u for u in k.uops if u.uop is UOps.STORE] + accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC] + stores = [u for u in k.uops if u.op is UOps.STORE] # the first store is to lds and can be upcasted - assert accs[0].dtype == stores[0].vin[-1].dtype == dtypes.float.vec(4) - assert stores[0].vin[0].uop is UOps.DEFINE_LOCAL + assert accs[0].dtype == stores[0].src[-1].dtype == dtypes.float.vec(4) + assert stores[0].src[0].op is UOps.DEFINE_LOCAL # the second store is to gds with no upcasts - assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float - assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL + assert accs[1].dtype == stores[1].src[-1].dtype == dtypes.float + assert stores[1].src[0].op is UOps.DEFINE_GLOBAL @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") def test_upcast_multireduce_nested_local_upcast(self): @@ -449,7 +449,7 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() - num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU]) + num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU]) assert num_ops == 0, "more alu uops than needed" def test_sum_acc_dtype(self): @@ -458,14 +458,14 @@ class TestLinearizer(unittest.TestCase): a = Tensor([1, 2, 3], dtype=tensor_dtype).sum() k = Linearizer(*create_schedule([a.lazydata])[-1].ast) k.linearize() - local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC] + local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC] assert local[0].dtype == acc_dtype def 
test_arg_acc_dtype(self): def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType): k = Linearizer(*create_schedule([c.lazydata])[-1].ast) k.linearize() - local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC] + local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC] assert local[0].dtype == expected_dtype tests = ( @@ -530,7 +530,7 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(realized_ast) k.apply_tensor_cores(1, axis=axis, tc_opt=2) k.linearize() - assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered" + assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered" assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" prg = CompiledRunner(k.to_program()) @@ -571,8 +571,8 @@ class TestLinearizer(unittest.TestCase): r = x.matmul(y, acc_dtype=tc.dtype_out) k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] for u in k.uops: - if u.uop is UOps.WMMA: - assert u.vin[-1].vin[0].uop != UOps.PHI + if u.op is UOps.WMMA: + assert u.src[-1].src[0].op != UOps.PHI @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_unroll_casted_phi(self): @@ -581,9 +581,9 @@ class TestLinearizer(unittest.TestCase): r = x.matmul(y, acc_dtype=tc.dtype_out) k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] for u in k.uops: - if u.uop is UOps.WMMA: - assert u.vin[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) - assert u.vin[-1].vin[0].uop != UOps.PHI + if u.op is UOps.WMMA: + assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) + assert u.src[-1].src[0].op != UOps.PHI @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_unroll_casted_phi_with_children(self): @@ -593,9 +593,9 @@ class TestLinearizer(unittest.TestCase): r = x.matmul(y, acc_dtype=tc.dtype_out).relu() k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1] for u in k.uops: - if u.uop is UOps.WMMA: - assert u.vin[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) - assert u.vin[-1].vin[0].uop != UOps.PHI + if u.op is UOps.WMMA: + assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2])) + assert u.src[-1].src[0].op != UOps.PHI @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_simple_unroll_no_between_phi_dependencies(self): @@ -604,11 +604,11 @@ class TestLinearizer(unittest.TestCase): k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1] # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x PHI -> ENDRANGE for u in k.uops: - if u.uop is UOps.PHI: - assert u.vin[1].uop is UOps.ALU + if u.op is UOps.PHI: + assert u.src[1].op is UOps.ALU # children of PHI are placed after ENDRANGE - if any(x.uop is UOps.PHI for x in u.vin): - end_range = [i for i, x in enumerate(k.uops) if x.uop is UOps.ENDRANGE][0] + if any(x.op is UOps.PHI for x in u.src): + end_range = [i for i, x in enumerate(k.uops) if x.op is UOps.ENDRANGE][0] assert end_range < k.uops.uops.index(u) def test_grouped_dims(self): @@ -649,7 +649,7 @@ class TestLinearizer(unittest.TestCase): sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps] assert len(sched) == 1 lin = Linearizer(*sched[0].ast) - assert not any(u.uop 
is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse" + assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse" def test_assign_fold(self): a = Tensor.ones(4, 4).contiguous().realize() @@ -680,10 +680,10 @@ class TestLinearizer(unittest.TestCase): k.hand_coded_optimizations() uops = list(k.linearize().uops) # ignore kernel optimized IF statements for now - if if_op:=next((u for u in uops if u.uop is UOps.IF), None): + if if_op:=next((u for u in uops if u.op is UOps.IF), None): uops = uops[:uops.index(if_op)] - assert len(set([u.uop for u in uops if u.uop in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both" - assert len([u for u in uops if u.uop is UOps.PHI]) == 0, "PHI should have been simplified" + assert len(set([u.op for u in uops if u.op in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both" + assert len([u for u in uops if u.op is UOps.PHI]) == 0, "PHI should have been simplified" # TODO: once uops track min/max this will be fixed #assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops" @@ -708,17 +708,17 @@ class TestLinearizer(unittest.TestCase): out = x.matmul(y) k = helper_linearizer_opt(out)[-1] # check that the float4 cast collapses - store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE] + store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE] for val in store_vals: - assert val.dtype == dtypes.float.vec(4) and val.uop is not UOps.CAST + assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.CAST @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_grouped_store_values(self): x = Tensor.randn((4,3,6,6)).realize() out = x.flip((0,1)).contiguous() k = helper_linearizer_opt(out)[-1] - store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0] - assert store_val.dtype == dtypes.float.vec(4) and store_val.uop is not UOps.CAST + store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0] + assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.CAST @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @@ -729,16 +729,16 @@ class TestLinearizer(unittest.TestCase): opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces k = helper_linearizer_opt(out, opts=[opt])[-1] - def get_recursive(uop): return set.union(set(uop.vin), [uop], *[get_recursive(v) for v in uop.vin]) - local_stores = [u for u in k.uops if u.uop is UOps.STORE and any(x.uop is UOps.DEFINE_LOCAL for x in get_recursive(u.vin[0]))] - global_stores = [u for u in k.uops if u.uop is UOps.STORE and any(x.uop is UOps.DEFINE_GLOBAL for x in get_recursive(u.vin[0]))] - barrier = [u for u in k.uops if u.uop is UOps.BARRIER][0] + def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src]) + local_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_LOCAL for x in get_recursive(u.src[0]))] + global_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_GLOBAL for x in get_recursive(u.src[0]))] + barrier = [u for u in k.uops if u.op is UOps.BARRIER][0] # check that the float4 cast collapses for all stores for store in local_stores+global_stores: - assert 
store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop is not UOps.CAST + assert store.src[-1].dtype == dtypes.float.vec(2) and store.src[-1].op is not UOps.CAST # check the children's vins - assert barrier.vin == tuple(local_stores) - assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1 + assert barrier.src == tuple(local_stores) + assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1 @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @@ -747,14 +747,14 @@ class TestLinearizer(unittest.TestCase): x, y = Tensor.rand(1,128), Tensor.rand(128, 128) r = (x@y).relu() k = helper_linearizer_opt(r)[-1] - stores = [u for u in k.uops if u.uop is UOps.STORE] + stores = [u for u in k.uops if u.op is UOps.STORE] # the float4 value stores directly in lds and we skip upcast - assert stores[0].vin[-1].dtype == dtypes.float.vec(4) - assert stores[0].vin[-1].uop is not UOps.CAST + assert stores[0].src[-1].dtype == dtypes.float.vec(4) + assert stores[0].src[-1].op is not UOps.CAST # the global store doesn't change - assert stores[1].vin[-1].dtype == dtypes.float + assert stores[1].src[-1].dtype == dtypes.float @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") def test_skip_unmatching_upcasts(self): @@ -764,8 +764,8 @@ class TestLinearizer(unittest.TestCase): Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2) ] k = helper_linearizer_ast(ast, [Tensor.empty(240*40).realize()], opts=[opt])[-1] - out = [u for u in k.uops if u.uop is UOps.STORE][0] - assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4) + out = [u for u in k.uops if u.op is UOps.STORE][0] + assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(4) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") @@ -775,15 +775,15 @@ class TestLinearizer(unittest.TestCase): Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)] k = helper_linearizer_ast(ast, [Tensor.empty(8*32).realize()], opts=[opt])[-1] - out = [u for u in k.uops if u.uop is UOps.STORE][0] - assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(2) + out = [u for u in k.uops if u.op is UOps.STORE][0] + assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(2) @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4") class TestFloat4(unittest.TestCase): @staticmethod def count_float4(k): - return (len([uop for uop in k.uops if uop.uop is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]), - len([uop for uop in k.uops if uop.uop is UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)])) + return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]), + len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(4)])) # TODO: express opts below as auto opts @@ -1178,10 +1178,10 @@ class TestKernelOpts(unittest.TestCase): for k in lins: seen_bar = False for u in k.uops: - if u.uop is UOps.BARRIER: + if u.op is UOps.BARRIER: assert not 
seen_bar, "redudant barrier" seen_bar = True - elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False + elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False @unittest.skip("TODO: broken") @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @@ -1200,10 +1200,10 @@ class TestKernelOpts(unittest.TestCase): for k in lins: seen_bar = False for u in k.uops: - if u.uop is UOps.BARRIER: + if u.op is UOps.BARRIER: assert not seen_bar, "redudant barrier" seen_bar = True - elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False + elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @@ -1225,10 +1225,10 @@ class TestKernelOpts(unittest.TestCase): for k in lins: seen_bar = False for u in k.uops: - if u.uop is UOps.BARRIER: + if u.op is UOps.BARRIER: assert not seen_bar, "redudant barrier" seen_bar = True - elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False + elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False def test_upcasts(self): N = 16 diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py index 35ce3fd514..3d2455ee09 100644 --- a/test/test_pattern_matcher.py +++ b/test/test_pattern_matcher.py @@ -6,7 +6,7 @@ from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat class TestPatternMatcher(unittest.TestCase): def assert_equiv_uops(self, uop1:UOp, uop2:UOp): # NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops - self.assertEqual(uop1.uop, uop2.uop) + self.assertEqual(uop1.op, uop2.op) self.assertEqual(uop1.dtype, uop2.dtype) self.assertEqual(uop1.arg, uop2.arg) @@ -64,7 +64,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c3), c3) def test_dup_name(self): - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)]) + matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)]) y1 = UOp(UOps.CONST, dtypes.float, arg=1.0) y2 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD) @@ -91,13 +91,13 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c4), None) def test_vin_one(self): - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)]) + matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) self.assertEqual(matcher.rewrite(c3), c3) self.assertEqual(matcher.rewrite(c2), None) - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)]) + matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)]) c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD) c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD) self.assertEqual(matcher.rewrite(c3), None) @@ -105,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c5), None) def test_vin_permutations(self): - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", 
vin=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)]) + matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) @@ -118,7 +118,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c6), None) def test_vin_repeat(self): - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=UPat(UOps.CONST)), lambda x: x)]) + matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD) @@ -127,7 +127,7 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c4), None) def test_allow_len(self): - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)]) + matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c3 = UOp(UOps.CONST, dtypes.float, arg=3.0) @@ -157,7 +157,7 @@ class TestPatternMatcher(unittest.TestCase): matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int), lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))]) matcher.rewrite_graph(uops) - uops.remove_childless(set(x for x in uops if x.uop in {UOps.STORE})) + uops.remove_childless(set(x for x in uops if x.op in {UOps.STORE})) self.assertEqual(len(uops.uops), 3) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 6c157e4f34..45e8ec1040 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -12,7 +12,7 @@ class TestUOpGraph(unittest.TestCase): g = UOpGraph([out]) self.assertEqual(len(g.uops), 1) out = g.uops[-1] - self.assertEqual(out.uop, UOps.CONST) + self.assertEqual(out.op, UOps.CONST) self.assertEqual(out.arg, 3.0) def test_where_same_fold(self): @@ -24,7 +24,7 @@ class TestUOpGraph(unittest.TestCase): g = UOpGraph([out]) self.assertEqual(len(g.uops), 1) out = g.uops[-1] - self.assertEqual(out.uop, UOps.CONST) + self.assertEqual(out.op, UOps.CONST) self.assertEqual(out.arg, 1.0) def test_where_const_fold(self): @@ -35,7 +35,7 @@ class TestUOpGraph(unittest.TestCase): g = UOpGraph([out]) self.assertEqual(len(g.uops), 1) out = g.uops[-1] - self.assertEqual(out.uop, UOps.CONST) + self.assertEqual(out.op, UOps.CONST) self.assertEqual(out.arg, 2.0) def test_const_cast(self): @@ -44,7 +44,7 @@ class TestUOpGraph(unittest.TestCase): g = UOpGraph([out]) self.assertEqual(len(g.uops), 1) out = g.uops[-1] - self.assertEqual(out.uop, UOps.CONST) + self.assertEqual(out.op, UOps.CONST) self.assertEqual(out.arg, 0) def test_cast_vectorized_fold(self): @@ -56,7 +56,7 @@ class TestUOpGraph(unittest.TestCase): alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT) out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu)) g = UOpGraph([out]) - self.assertEqual(len([x for x in g.uops if x.uop is UOps.CAST]), 0) + self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0) def test_depth_2_const_fold(self): v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1)) @@ -67,10 +67,10 @@ class TestUOpGraph(unittest.TestCase): g = UOpGraph([out]) self.assertEqual(len(g.uops), 3) out = g.uops[-1] - self.assertEqual(out.uop, UOps.ALU) + self.assertEqual(out.op, UOps.ALU) 
self.assertEqual(out.arg, BinaryOps.ADD) - self.assertEqual(out.vin[1].uop, UOps.CONST) - self.assertEqual(out.vin[1].arg, 6) + self.assertEqual(out.src[1].op, UOps.CONST) + self.assertEqual(out.src[1].arg, 6) if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/test_uops.py b/test/test_uops.py index 7c8ac096fb..9a5b8e5d18 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -226,7 +226,7 @@ class TestConstantFolding(unittest.TestCase): si = create_schedule([t.lazydata]) assert len(si) == 1 ji = lower_schedule_item(si[-1]) - assert any(uop.uop is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.uop for uop in ji.prg.p.uops]} does not contain bitcast" + assert any(uop.op is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast" class TestGatedStoreRewrite(unittest.TestCase): @unittest.skip("not yet implemented") @@ -240,9 +240,9 @@ class TestGatedStoreRewrite(unittest.TestCase): gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT) uops = UOpGraph([UOp(UOps.STORE, None, (gmem, idx, value, gate))]) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) - if_uop = next(u for u in uops if u.uop is UOps.IF) - endif = next(u for u in uops if u.uop is UOps.ENDIF) - assert endif.vin[0] is if_uop + if_uop = next(u for u in uops if u.op is UOps.IF) + endif = next(u for u in uops if u.op is UOps.ENDIF) + assert endif.src[0] is if_uop nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)]) assert nested_uops == (gmem, gidx0, idx, value) @@ -261,9 +261,9 @@ class TestGatedStoreRewrite(unittest.TestCase): outs.append(UOp(UOps.STORE, None, (gmem1, idx, value1))) uops = UOpGraph(outs) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) - if_uop = next(u for u in uops if u.uop is UOps.IF) - endif = next(u for u in uops if u.uop is UOps.ENDIF) - assert endif.vin[0] is if_uop + if_uop = next(u for u in uops if u.op is UOps.IF) + endif = next(u for u in uops if u.op is UOps.ENDIF) + assert endif.src[0] is if_uop nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)]) assert nested_uops == (gmem0, value0) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index ceeaf31696..fddfe2a31d 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -29,23 +29,23 @@ class UOps(Enum): def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x @dataclass(eq=False) class UOp: - uop: UOps + op: UOps dtype: Optional[DType] = None - vin: Tuple[UOp, ...] = tuple() + src: Tuple[UOp, ...] = tuple() arg: Any = None - def tuple(self): return (self.uop, self.dtype, self.vin, self.arg) + def tuple(self): return (self.op, self.dtype, self.src, self.arg) def commutative(self) -> bool: - return self.uop is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR} + return self.op is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR} @functools.cached_property def cmp_tuple(self): # NOTE: this sort of DEFINE_VAR shouldn't have to be here. 
only for PTX - return (self.uop.value, (self.arg if self.uop is not UOps.DEFINE_VAR else self.arg.expr) if self.uop is not UOps.ALU else \ - (type(self.uop), self.uop.value), self.dtype, self.vin) + return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \ + (type(self.op), self.op.value), self.dtype, self.src) def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple def __repr__(self): - return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}" + return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}" def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,)) - def name(self, name:Optional[str]): return UOp(UOps.VAR, vin=(self,), arg=name) + def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name) def __neg__(self): return UOp.alu(UnaryOps.NEG, self) def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x)) def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self) @@ -75,35 +75,35 @@ class UOp: @staticmethod def cvar(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name) @functools.cached_property - def parents(self) -> Set[UOp]: return set.union(set(self.vin), *[x.parents for x in self.vin]) + def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src]) @property # parents with self def sparents(self) -> Set[UOp]: return set([self]).union(self.parents) - def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.uop is UOps.DEFINE_VAR]) + def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR]) def uop_alu_resolve(u:UOp) -> sint: - if u.uop is UOps.CONST: return u.arg - if u.uop is UOps.DEFINE_VAR: return u.arg - if u.uop is UOps.SPECIAL: return u.arg[2]-1 - if u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1]) - if u.uop is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.vin[0]) * (2**cast(int, uop_alu_resolve(u.vin[1]))) - if u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1]) - raise RuntimeError(f"ALU resolve fail @ {u.uop}") + if u.op is UOps.CONST: return u.arg + if u.op is UOps.DEFINE_VAR: return u.arg + if u.op is UOps.SPECIAL: return u.arg[2]-1 + if u.op is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.src[0]) * uop_alu_resolve(u.src[1]) + if u.op is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.src[0]) * (2**cast(int, uop_alu_resolve(u.src[1]))) + if u.op is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.src[0]) + uop_alu_resolve(u.src[1]) + raise RuntimeError(f"ALU resolve fail @ {u.op}") # *** simplification logic *** @dataclass(frozen=True) class UPat: - uop: Optional[Union[UOps, Set[UOps]]] = None + op: Optional[Union[UOps, Set[UOps]]] = None arg: Any = None - vin: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None + src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None name: Optional[str] = None dtype: Optional[Union[DType, Set[DType]]] = None allow_len: Set[int] = field(default_factory=set) @staticmethod def compile(u: UOp, name:Optional[str]=None) -> UPat: - if u.uop is UOps.VAR: return UPat(name=name or u.arg, 
dtype=u.dtype) if len(u.vin) == 0 else UPat.compile(u.vin[0], name or u.arg) - return UPat(u.uop, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.vin]) if u.vin != () else None, name, u.dtype) + if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg) + return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.src]) if u.src != () else None, name, u.dtype) T = TypeVar("T") def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: @@ -117,15 +117,15 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool: if pat.name is not None: store[pat.name] = uop if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False - if pat.uop is not None and __unmatch(pat.uop, uop.uop): return False - if pat.vin is None: return True + if pat.op is not None and __unmatch(pat.op, uop.op): return False + if pat.src is None: return True # only one if it's a tuple # try all permutations if it's a list # repeat if it's a UPat - for vp in itertools.permutations(pat.vin) if isinstance(pat.vin,list) else ([pat.vin] if isinstance(pat.vin,tuple) else [(pat.vin,)*len(uop.vin)]): - if len(uop.vin) != len(vp) and (len(uop.vin) not in pat.allow_len): return False + for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]): + if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False new_store = store.copy() - if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)): + if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)): store.update(new_store) return True return False @@ -137,14 +137,14 @@ class PatternMatcher: # uop is required, arg is optional for p,fxn in self.patterns: if isinstance(p, UOp): p = UPat.compile(p) - assert p.uop is not None - if isinstance(p.uop, set): - for uop in p.uop: self.pdict[(uop, p.arg)].append((p, fxn)) + assert p.op is not None + if isinstance(p.op, set): + for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn)) else: - self.pdict[(p.uop, p.arg)].append((p, fxn)) + self.pdict[(p.op, p.arg)].append((p, fxn)) def rewrite(self, uop:UOp) -> Optional[UOp]: - for p,fxn in itertools.chain(self.pdict[(uop.uop, uop.arg)], self.pdict[(uop.uop, None)]): + for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]): store: Dict[str, UOp] = {} if _match(uop, p, store): return fxn(**store) return None @@ -152,7 +152,7 @@ class PatternMatcher: def sum_collapse(phi_input, loop, val1, val2): for v1,v2 in [(val1, val2), (val2, val1)]: if loop not in v1.parents: - loop_range = loop.vin[1]-loop.vin[0] + loop_range = loop.src[1]-loop.src[0] ret = v1*loop_range.cast(v1.dtype) return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret return None @@ -168,34 +168,34 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst): # this is symbolic 2.0 constant_folder = PatternMatcher([ # arange loop folding (early) - (UPat(UOps.ALU, TernaryOps.WHERE, vin=(UPat(UOps.ALU, BinaryOps.CMPLT, vin=( - UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, - vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]), + (UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=( + UPat(UOps.ALU, BinaryOps.ADD, 
src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL, + src=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, src=(UPat(name="loop_start"), UPat(name="loop_end")))])]), UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse), # sum collapse to mul (with possible GEP) - (UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="phi_input", vin=(UPat(UOps.RANGE, name="loop"),)), - UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), - (UPat(UOps.PHI, vin=(UPat(UOps.GEP, name="phi_input", - vin=(UPat(UOps.DEFINE_ACC, vin=(UPat(UOps.RANGE, name="loop"),)),)), - UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), + (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=(UPat(UOps.RANGE, name="loop"),)), + UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), + (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", + src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.RANGE, name="loop"),)),)), + UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), # deal with UNMUL (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), - UPat(UOps.UNMUL, vin=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]), + UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]), lambda c1,c2,v: v if c1.arg == c2.arg else None), - (UOp(UOps.UNMUL, vin=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero), - (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))), + (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero), + (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))), # max on special can go away (TODO: special should be variable, same thing applies) (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None), # const rules - (UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)), - (UPat(UOps.CAST, name="root", vin=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)), + (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)), + (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)), # a phi on a DEFINE_ACC without loops or a CONST is a noop. 
this is for correctness, not just speed - (UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])), - (UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, vin=tuple()), UPat(name="x"))), lambda x: x), - (UPat(UOps.PHI, vin=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x), + (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])), + (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=tuple()), UPat(name="x"))), lambda x: x), + (UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x), # a DEFINE_ACC without inputs is a const + GEP on a const is the const - (UPat(UOps.DEFINE_ACC, name="root", vin=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])), - (UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)), + (UPat(UOps.DEFINE_ACC, name="root", src=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])), + (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)), # max -2147483648 (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x), # -(-x) -> x @@ -209,7 +209,7 @@ constant_folder = PatternMatcher([ (UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val), (UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1), # ** constant folding ** - (UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))), + (UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))), # ** self folding ** (UOp.var('x') + 0, lambda x: x), # x+0 -> x (UOp.var('x') - 0, lambda x: x), # x-0 -> x @@ -222,8 +222,8 @@ constant_folder = PatternMatcher([ (UOp.var('x') * 0, lambda x: x if isinstance(x.arg, float) and math.isnan(x.arg) else UOp.const(x.dtype, 0)), (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0 # ** load/store folding ** - (UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), - UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)), + (UPat(UOps.STORE, src=(UPat(name="buf"), UPat(name="idx"), + UPat(UOps.LOAD, src=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)), # ** two stage add/sub folding ** ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))), ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))), @@ -247,17 +247,17 @@ constant_folder = PatternMatcher([ (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))), lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)), # store float4/float2 directly (remove CAST/GEP) - (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(4)))), UOp.store), - (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(2)))), UOp.store), + (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i 
in range(4)))), UOp.store), + (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store), # CAST-PHI-GEP -> PHI-CAST - (UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))), + (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))), lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))), - (UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))), + (UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))), lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))), # NEG/CMPLT -> CMPLT (UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)), # cast NOOP (NOTE: it's str to deal with PtrDType) - (UPat(UOps.CAST, name="root"), lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None), + (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None), ]) # *** uop graph *** @@ -271,8 +271,8 @@ class UOpGraph: def __iter__(self) -> Iterator[UOp]: return iter(self.uops) def __getitem__(self, index) -> UOp: return self.uops[index] - def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR], key=lambda v: v.expr) - def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL] + def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr) + def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL] @property def uops(self): @@ -285,7 +285,7 @@ class UOpGraph: def print(self): for i,u in enumerate(self): - print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") + print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}") def graph_rewrite(self, sink, pm): # recursive rewrite @@ -305,7 +305,7 @@ class UOpGraph: recurse_cnt += 1 changed += recurse_cnt # NOTE: this changes UOp, so we have to delete caches - up.vin = tuple(rewrite(x) for x in up.vin) + up.src = tuple(rewrite(x) for x in up.src) if 'parents' in up.__dict__: delattr(up, 'parents') if 'cmp_tuple' in up.__dict__: delattr(up, 'cmp_tuple') # replace with cached nodes @@ -327,16 +327,16 @@ class UOpGraph: n = unprocessed_nodes.pop(0) if n in all_nodes: continue all_nodes[n] = None - for x in n.vin: + for x in n.src: early_in_degree[n] += 1 children[x].append(n) - unprocessed_nodes += list(n.vin) + unprocessed_nodes += list(n.src) early_queue = [x for x in all_nodes if early_in_degree[x] == 0] replace_nodes: Dict[UOp, UOp] = {} while len(early_queue): n = early_queue.pop(0) if n in replace_nodes: continue - key = (n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg) + key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg) if found:=self.nodes.get(key): replace_nodes[n] = found else: 
replace_nodes[n] = self.nodes[key] = UOp(*key) for x in children[n]: @@ -367,29 +367,29 @@ class UOpGraph: def add_parents(u:UOp): if u in nodes: return nodes[u] = None - for x in u.vin: + for x in u.src: add_parents(x) in_degree[u] += 1 graph[x].append(u) - if u.uop is UOps.RANGE: loops.append(u) - if u.uop is UOps.IF: ifs.append(u) - sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP)) + if u.op is UOps.RANGE: loops.append(u) + if u.op is UOps.IF: ifs.append(u) + sink = UOp(UOps.SINK, None, tuple(x for x in sink.src if x.op is not UOps.NOOP)) add_parents(sink) @functools.lru_cache(None) def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]: - if x.uop is UOps.SINK: return set() - return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in graph[x] if x.uop is not end])) + if x.op is UOps.SINK: return set() + return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in graph[x] if x.op is not end])) # scope children impact the toposort and END* insertion end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)} - scope_children = {p:get_recursive_children(p, end_for_uop[p.uop][0]) for p in (loops+ifs)[::-1]} + scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]} queue: List = [] def push(u): priority = 0 # prefer uops that are loop children for l, ss in scope_children.items(): - if l.uop is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1] + if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1] heapq.heappush(queue, (priority, u)) for u in nodes: @@ -403,20 +403,20 @@ class UOpGraph: while queue: p,x = heapq.heappop(queue) if DEBUG >= 7: print(p,x) - if x.uop is UOps.DEFINE_ACC and len(x.vin): - idx = min([self._uops.index(l) for l in x.vin]) + if x.op is UOps.DEFINE_ACC and len(x.src): + idx = min([self._uops.index(l) for l in x.src]) self._uops.insert(idx, x) else: self._uops.append(x) for u, ss in scope_children.items(): if x in ss: ss.remove(x) - if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.uop][1], None, (u,))) + if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,))) for u in graph[x]: in_degree[u] -= 1 if in_degree[u] == 0: push(u) - assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}" + assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}" self._uops = self._uops[:-1] if type_verify: self.type_verify() @@ -431,34 +431,34 @@ class UOpGraph: dont_count: Set[UOp] = set() if ignore_indexing: for u in self.uops: - if u.uop is UOps.LOAD: - dont_count = dont_count.union(u.vin[1].sparents) - if len(u.vin) > 3: dont_count = dont_count.union(u.vin[2].sparents) - elif u.uop is UOps.STORE: - dont_count = dont_count.union(u.vin[1].sparents) - if len(u.vin) > 3: dont_count = dont_count.union(u.vin[3].sparents) + if u.op is UOps.LOAD: + dont_count = dont_count.union(u.src[1].sparents) + if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents) + elif u.op is UOps.STORE: + dont_count = dont_count.union(u.src[1].sparents) + if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents) for u in self.uops: - if u.uop is UOps.RANGE: + if u.op is UOps.RANGE: mult_stack.append(mults) - mults *= uop_alu_resolve(u.vin[1]) - elif u.uop is UOps.ENDRANGE: + mults *= uop_alu_resolve(u.src[1]) + elif u.op is UOps.ENDRANGE: mults = mult_stack.pop(-1) - elif 
-      elif u.uop is UOps.LOAD:
+      elif u.op is UOps.LOAD:
        assert u.dtype is not None
        mem += u.dtype.itemsize * mults
-      elif u.uop is UOps.STORE:
-        assert u.vin[2].dtype is not None
-        mem += u.vin[2].dtype.itemsize * mults
-      elif u.uop is UOps.ALU and u not in dont_count:
+      elif u.op is UOps.STORE:
+        assert u.src[2].dtype is not None
+        mem += u.src[2].dtype.itemsize * mults
+      elif u.op is UOps.ALU and u not in dont_count:
        flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
-      elif u.uop is UOps.WMMA and u not in dont_count:
+      elif u.op is UOps.WMMA and u not in dont_count:
        assert u.arg[1] is not None
        flops += 2 * prod(u.arg[1]) // 32 * mults
    return flops, mem
  def type_verify(self):
    for u in self.uops:
-      uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
+      uop, arg, vin, dtype = u.op, u.arg, u.src, u.dtype
      if uop in {UOps.CONST, UOps.DEFINE_ACC}:
        if uop is UOps.DEFINE_ACC: arg = arg[0]
        assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
diff --git a/tinygrad/engine/graph.py b/tinygrad/engine/graph.py
index f9a5e5d013..c4924cfc8e 100644
--- a/tinygrad/engine/graph.py
+++ b/tinygrad/engine/graph.py
@@ -77,13 +77,12 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
 def _tree(luop:Union[LazyOp,UOp], cycles, cnt, prefix=""):
   cnt[0] += 1
-  if len(src:=luop.vin if hasattr(luop,'vin')else luop.src) == 0:
-    return [f"━━ {prefix}{(luop.op if hasattr(luop, 'op') else luop.uop).name} {luop.arg if luop.arg else ''}"]
+  if len(luop.src) == 0: return [f"━━ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
   if (lid := id(luop)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
-    return [f"━⬆︎ goto {cycles[id(luop)][0]}: {(luop.op if hasattr(luop,'op')else luop.uop).name}"]
+    return [f"━⬆︎ goto {cycles[id(luop)][0]}: {luop.op.name}"]
   cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
-  lines = [f"━┳ {prefix}{(luop.op if hasattr(luop,'op')else luop.uop).name} {luop.arg if luop.arg else ''}"]
-  childs = [_tree(c, cycles, cnt) for c in src[:]]
+  lines = [f"━┳ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
+  childs = [_tree(c, cycles, cnt) for c in luop.src[:]]
   for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
   return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
@@ -95,7 +94,7 @@ def graph_uops(uops:List[UOp]):
             UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
   G = nx.DiGraph()
   for u in uops:
-    if u.uop in {UOps.ENDRANGE, UOps.ENDIF}: continue
-    G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
-    for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
+    if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
+    G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
+    for v in u.src: G.add_edge(uops.index(v), uops.index(u))
   save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index 95ba7dfeb3..c73f5bc720 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -131,22 +131,22 @@ class PTXRenderer(Renderer):
      return ret
    for u in uops:
-      uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
+      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
      if uop is UOps.IF:
        assert vin[0].dtype is not None
        kk(*self.render_bra(f"IF_{r[vin[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
      elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
      elif uop is UOps.ENDRANGE:
        kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
-           self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
+           self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].src[1]], dtypes.int, self.types[dtypes.int]))
        kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred))
      elif uop is UOps.ENDIF:
-        kk(f"IF_{r[vin[0].vin[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
+        kk(f"IF_{r[vin[0].src[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
      elif uop is UOps.STORE:
        assert vin[0].dtype is not None and vin[2].dtype is not None
        assert vin[0].dtype == dtypes.int64, "store isn't int64"
-        assert vin[1].uop is UOps.CONST, f"store isn't const {u}"
-        mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
+        assert vin[1].op is UOps.CONST, f"store isn't const {u}"
+        mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
        if vin[2].dtype.count > 1:
          kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
            f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_types[vin[2].dtype.scalar()]} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
@@ -178,8 +178,8 @@ class PTXRenderer(Renderer):
      elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
      elif uop is UOps.LOAD:
        assert vin[0].dtype == dtypes.int64, "load isn't int64"
-        assert vin[1].uop is UOps.CONST, f"load isn't const {u}"
-        mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
+        assert vin[1].op is UOps.CONST, f"load isn't const {u}"
+        mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
        if dtype.count > 1:
          r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
          if(len(vin)>3):
@@ -227,44 +227,44 @@ class PTXRenderer(Renderer):
 ptx_matcher = PatternMatcher([
  (UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-    vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
+    src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
    lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHL)),
  (UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
-    vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
+    src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
    lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
-  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)),
+  (UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
"root"), - lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)), + lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)), (UPat(UOps.ALU, BinaryOps.ADD, [UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"), - lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)), + lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)), *[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"), - lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),))) + lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),))) for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half], - (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))), - lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)), - (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(),UPat())), - lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))), - (UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())), - lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)), - (UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))), - lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)), - (UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))), - lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)), + (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))), + lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)), + (UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())), + lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))), + (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())), + lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)), + (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))), + lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)), + (UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))), + lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)), # ptr_ar (load/store) - (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}), - UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))), - lambda root, alu, const: UOp(root.uop, root.dtype, - (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64), - UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])), - (UPat({UOps.LOAD, UOps.STORE}, name="root", 
-  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+    UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
+    lambda root, alu, const: UOp(root.op, root.dtype,
+      (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+       UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
    UPat(UOps.CONST, name="const"))),
-    lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
-      UOp.const(dtypes.int64, const.arg * root.vin[0].dtype.itemsize),
-      )+root.vin[2:])),
-  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
+    lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
+      UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
+      )+root.src[2:])),
+  (UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
    UPat(name="alu"))), # no const here
-    lambda root, alu: UOp(root.uop, root.dtype,
-      (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
-       UOp.const(dtypes.int64, 0))+root.vin[2:])),
+    lambda root, alu: UOp(root.op, root.dtype,
+      (alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
+       UOp.const(dtypes.int64, 0))+root.src[2:])),
 ])
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 029df601c3..a9733f8731 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -104,10 +104,10 @@ class CStyleLanguage(Renderer):
      c[prefix] += 1
      return ret
-    child_count = Counter(v for ru in uops for v in ru.vin)
+    child_count = Counter(v for ru in uops for v in ru.src)
    for u in uops:
-      uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
+      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
      # these four uops don't have output dtypes
      if uop is UOps.IF:
        kk(f"if ({r[vin[0]]}) {{")
@@ -118,7 +118,7 @@ class CStyleLanguage(Renderer):
        kk("}")
      elif uop is UOps.STORE:
        assert vin[0].dtype is not None and vin[2].dtype is not None
-        rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
+        rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
        kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
@@ -138,7 +138,7 @@ class CStyleLanguage(Renderer):
          kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
          r[u] = args[1]
        elif uop is UOps.LOAD:
-          val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
+          val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
          # NOTE: this relies on the load not happening if it's in the unselected branch
          if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
          kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
@@ -169,7 +169,7 @@ class CStyleLanguage(Renderer):
        elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
        elif uop is UOps.GEP:
          assert vin[0].dtype is not None
-          from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
+          from_ssa = vin[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
          r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
        else: raise RuntimeError(f"failed to render {uop}")
@@ -235,7 +235,7 @@ class MetalRenderer(CStyleLanguage):
    return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
  def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA])
+    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
    for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
  simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
  b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
@@ -285,7 +285,7 @@ class CUDARenderer(CStyleLanguage):
      prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
    # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
-    for arg in dedup([uop.arg for uop in uops if uop.uop is UOps.WMMA]):
+    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
      fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
      prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
@@ -370,7 +370,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
    prefix += [_make_hip_dtype(*x) for x in vec_dts]
-    for arg in dedup([uop.arg for uop in uops if uop.uop is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
      if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
      else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
  half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
@@ -379,7 +379,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)
  def get_kernel_modifier(self, uops:UOpGraph) -> str:
-    requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop is UOps.SPECIAL and u.arg[1][0] == "l")
+    requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
    # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
    # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
    return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index bef5282a63..eec5706242 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -78,7 +78,7 @@ class LLVMRenderer(Renderer):
    module = ir.Module(name=__file__)
    # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
-    buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
+    buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
    buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
    # create llvm function
@@ -101,7 +101,7 @@ class LLVMRenderer(Renderer):
      if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
    for u in uops:
-      uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
+      uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
      if uop is UOps.STORE:
        element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
        if len(vin) > 3:
@@ -115,7 +115,7 @@ class LLVMRenderer(Renderer):
        lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
        for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
        bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
-        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
+        bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].src[1]]), loop_entry_bb, bb[-1].block)
      else:
        assert dtype is not None, f"None dtype for uop {uop}"
        if uop is UOps.RANGE:
@@ -147,7 +147,7 @@ class LLVMRenderer(Renderer):
          lvars[u] = lvars[vin[1]]
          # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
          backward = vin[0]
-          while backward.uop is UOps.PHI: backward = backward.vin[0]
+          while backward.op is UOps.PHI: backward = backward.src[0]
          lvars[backward] = lvars[u]
        elif uop is UOps.ALU: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else vin[0].dtype)
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 4d0b89cf5b..eb107e9518 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -191,7 +191,7 @@ class PythonRenderer(Renderer):
    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
  def render(self, name:str, uops:UOpGraph) -> str:
-    lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
+    lops = [(u.op, u.dtype, [uops.uops.index(v) for v in u.src], u.arg) for u in uops]
    return base64.b64encode(pickle.dumps(lops)).decode()
 class PythonCompiler(Compiler):
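
Aside (not part of the patch): a minimal, self-contained sketch of the field naming this series converges on. Every UOp exposes its opcode as `.op` and its inputs as `.src` (previously `.uop` and `.vin`), so generic graph walks like the one the PythonRenderer hunk above performs all take the same shape. The tiny `UOp` stand-in and the `count_ops` helper below are hypothetical illustrations, not tinygrad APIs.

from dataclasses import dataclass
from typing import Any, Tuple
from collections import Counter

@dataclass(frozen=True)
class UOp:
    op: str                      # opcode name, e.g. "LOAD", "ALU", "STORE"
    src: Tuple["UOp", ...] = ()  # input edges (the field formerly called vin)
    arg: Any = None

def count_ops(root: UOp) -> Counter:
    # walk the graph through .src and tally opcodes via .op
    seen, stack, tally = set(), [root], Counter()
    while stack:
        u = stack.pop()
        if id(u) in seen: continue
        seen.add(id(u))
        tally[u.op] += 1
        stack.extend(u.src)
    return tally

# usage: a STORE whose value is the sum of two LOADs
a, b = UOp("LOAD"), UOp("LOAD")
print(count_ops(UOp("STORE", (UOp("ALU", (a, b), "ADD"),))))  # LOAD counted twice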