rename uop [run_process_replay] (#5031)

* rename

* fix unittests

* rename vin

* fix test

* fix type [run_process_replay]

* rm pre commit hook change
Authored by kormann on 2024-06-18 20:34:05 +02:00, committed by GitHub
parent dc942bf1f6
commit 7c3b877216
13 changed files with 280 additions and 281 deletions
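
For orientation, a minimal sketch (not part of the diff) of what the rename means for code that builds UOps by hand: the UOp attribute uop becomes op and vin becomes src, and the UPat vin= keyword becomes src= to match. The construction below mirrors the test cases further down; the import paths are assumed to follow the test files in this commit.

from tinygrad.codegen.uops import UOps, UOp
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps

# build a tiny ALU node the same way the pattern matcher tests do
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
add = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)

# before this commit: add.uop is UOps.ALU and add.vin == (c1, c2)
# after this commit:  add.op  is UOps.ALU and add.src == (c1, c2)
assert add.op is UOps.ALU and add.src == (c1, c2)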

View File

@@ -48,4 +48,4 @@ repos:
entry: env PYTHONPATH="." python3 -m pylint tinygrad/
language: system
always_run: true
pass_filenames: false
pass_filenames: false

View File

@@ -16,7 +16,7 @@ from tinygrad.ops import LazyOp, UnaryOps, BufferOps
from test.helpers import is_dtype_supported
def tuplize_uops(uops:List[UOp]) -> Tuple:
return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
return tuple([(x.op, x.dtype, tuple(uops.index(x) for x in x.src), x.arg) for x in uops])
device = Device[Device.DEFAULT]

View File

@@ -12,11 +12,11 @@ def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int]
paths: List[List[UOp]] = []
# TODO: express DEFINE_ACC and loop children conditions in the graph, builtin.
for p in find_all_toposorts(graph, in_degree):
assert p[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {p[-1]}"
assert p[-1].op is UOps.SINK, f"didn't end with SINK, ended with {p[-1]}"
paths.append(path:=list(p[:-1]))
for u in path:
if u.uop is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,)))
if u.uop is UOps.RANGE:
if u.op is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,)))
if u.op is UOps.RANGE:
path.insert(max(path.index(x) for x in loops_children[u] if x in path)+1, UOp(UOps.ENDRANGE, None, (u,)))
return paths
@@ -58,9 +58,9 @@ def find_all_toposorts(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[
def recurse_paths(path:List[UOp]):
for v, d in in_degree.items():
if d != 0 or v in visited: continue
if v.uop is UOps.DEFINE_ACC and any(l not in path for l in v.vin): continue
if v.op is UOps.DEFINE_ACC and any(l not in path for l in v.src): continue
for u in graph[v]: in_degree[u] -= 1
if v.uop is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.vin), v)
if v.op is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.src), v)
else: path.append(v)
visited.add(v)
recurse_paths(path)

View File

@@ -38,7 +38,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
np_c = np_a @ np_b
if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
@@ -54,7 +54,7 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA])
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
if ensure_triggered:
assert wmmas > 0, "tensor core not triggered"
@@ -94,8 +94,8 @@ class TestLinearizer(unittest.TestCase):
b_t = Tensor.full(st.shape, 3).contiguous().realize()
lin = helper_linearizer_ast((out0, out1), [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0]
stores = [u for u in lin.uops if u.uop is UOps.STORE]
mutable_bufs = [u for u in lin.uops if u.uop is UOps.DEFINE_GLOBAL and u.arg[-1]]
stores = [u for u in lin.uops if u.op is UOps.STORE]
mutable_bufs = [u for u in lin.uops if u.op is UOps.DEFINE_GLOBAL and u.arg[-1]]
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]
@@ -108,92 +108,92 @@ class TestLinearizer(unittest.TestCase):
load_t = Tensor.full(load.st.shape, 1).contiguous().realize()
k = helper_linearizer_ast(ast, [load_t], wanna_output=[load_t.numpy().sum()])[1]
self.assertEqual(k.uops[-1].uop, UOps.ENDIF)
self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.uop is UOps.STORE][-1]), k.uops.uops.index(k.uops[-1]))
self.assertEqual(k.uops[-1].op, UOps.ENDIF)
self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.op is UOps.STORE][-1]), k.uops.uops.index(k.uops[-1]))
def test_two_nested_range(self):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).sum()
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
if getenv("PTX"):
# RANGE -> 2xLOAD_INDEXING -> LOAD -> RANGE -> PHI
assert ranges[1] == ranges[0]+4
assert lin.uops[ranges[0]+3].uop is UOps.LOAD
assert lin.uops[ranges[0]+3].op is UOps.LOAD
else:
# RANGE -> LOAD -> RANGE -> PHI
assert ranges[1] == ranges[0]+2
assert lin.uops[ranges[0]+1].uop is UOps.LOAD
assert lin.uops[ranges[0]+1].op is UOps.LOAD
def test_three_nested_range(self):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
if getenv("PTX"):
# RANGE -> RANGE -> 2xLOAD_INDEXING -> LOAD -> RANGE -> PHI
assert ranges[2] == ranges[1]+4 == ranges[0]+5
assert lin.uops[ranges[1]+3].uop is UOps.LOAD
assert lin.uops[ranges[1]+3].op is UOps.LOAD
else:
# RANGE -> RANGE -> LOAD -> RANGE -> PHI
assert ranges[2] == ranges[1]+2 == ranges[0]+3
assert lin.uops[ranges[1]+1].uop is UOps.LOAD
assert lin.uops[ranges[1]+1].op is UOps.LOAD
def test_two_nested_range_alt_indexing(self):
a = Tensor([2, 2]).realize()
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum()
lin = helper_linearizer_opt(out, wanna_output=[24])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
if getenv("PTX"):
# RANGE -> CAST ridx -> LOAD_INDEXING -> 4x ALU -> RANGE -> LOAD -> RANGE -> PHI
assert ranges[1] == ranges[0]+6
assert lin.uops[ranges[1]+11].uop is UOps.ENDRANGE
assert lin.uops[ranges[1]+11].op is UOps.ENDRANGE
else:
# RANGE -> 4x ALU -> RANGE -> 9x ALU + 1x LOAD -> PHI
assert ranges[1] == ranges[0]+5
assert lin.uops[ranges[1]+11].uop is UOps.ENDRANGE
assert lin.uops[ranges[1]+11].op is UOps.ENDRANGE
def test_range_outer_op_before_phi(self):
a = Tensor.randn(4, 1).realize()
b = Tensor.randn(1, 1).realize()
out = (a + b[0]).sum() + b[0]
lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()[0]])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
# LOAD -> RANGE -> LOAD -> PHI
assert lin.uops[ranges[0]-2].uop is UOps.LOAD
assert lin.uops[ranges[0]-2].op is UOps.LOAD
def test_range_outer_op_before_phi_nested_range(self):
a = Tensor.randn(2, ).realize()
b = Tensor.randn(1, 1).realize()
out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()[0]])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
if getenv("PTX"):
# LOAD -> RANGE -> 3xLOAD_INDEXING -> LOAD -> ALU -> RANGE -> PHI
assert lin.uops[ranges[0]-2].uop is UOps.LOAD
assert lin.uops[ranges[0]-2].op is UOps.LOAD
assert ranges[1] == ranges[0]+5
assert [x.uop for x in lin.uops[ranges[0]+3:ranges[0]+5]] == [UOps.LOAD, UOps.ALU]
assert [x.op for x in lin.uops[ranges[0]+3:ranges[0]+5]] == [UOps.LOAD, UOps.ALU]
# LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI
else:
assert lin.uops[ranges[0]-2].uop is UOps.LOAD
assert lin.uops[ranges[0]-2].op is UOps.LOAD
assert ranges[1] == ranges[0]+3
assert [x.uop for x in lin.uops[ranges[0]+1:ranges[0]+3]] == [UOps.LOAD, UOps.ALU]
assert [x.op for x in lin.uops[ranges[0]+1:ranges[0]+3]] == [UOps.LOAD, UOps.ALU]
def test_range_outer_op_after_phi(self):
a = Tensor.randn(4, 1).realize()
out = a.sum() * a.sum()
lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0]
# RANGE -> LOAD -> PHI -> ALU
end = max(i for i,u in enumerate(lin.uops) if u.uop is UOps.ENDRANGE)
assert lin.uops[end+1].uop is UOps.ALU
end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
assert lin.uops[end+1].op is UOps.ALU
def test_range_outer_op_after_phi_nested_range(self):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum()
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0]
# RANGE -> LOAD -> PHI -> ALU
end = max(i for i,u in enumerate(lin.uops) if u.uop is UOps.ENDRANGE)
assert lin.uops[end+1].uop is UOps.ALU
end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
assert lin.uops[end+1].op is UOps.ALU
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -202,11 +202,11 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*ast)
k.hand_coded_optimizations()
k.linearize()
self.assertEqual(len(endifs:=[x for x in k.uops if x.uop is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.uop is UOps.IF]))
self.assertEqual(len(barriers:=[x for x in k.uops if x.uop is UOps.BARRIER]), 3)
self.assertEqual(k.uops[k.uops.uops.index(endifs[0])-1].uop, UOps.STORE)
self.assertEqual(len(endifs:=[x for x in k.uops if x.op is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.op is UOps.IF]))
self.assertEqual(len(barriers:=[x for x in k.uops if x.op is UOps.BARRIER]), 3)
self.assertEqual(k.uops[k.uops.uops.index(endifs[0])-1].op, UOps.STORE)
self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+1], barriers[1])
self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+2].uop, UOps.LOAD)
self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+2].op, UOps.LOAD)
self.assertLess(k.uops.uops.index(barriers[0]), k.uops.uops.index(ifs[0]))
self.assertLess(k.uops.uops.index(ifs[0]), k.uops.uops.index(endifs[0]))
self.assertLess(k.uops.uops.index(barriers[1]), k.uops.uops.index(ifs[1]))
@@ -245,13 +245,13 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*ast)
k.hand_coded_optimizations()
k.linearize()
local_buf = [u for u in k.uops if u.uop is UOps.DEFINE_LOCAL]
self.assertEqual(len(real_local_stores:=[u for u in k.uops if u.uop is UOps.STORE and any([lb in u.vin for lb in local_buf])]), 3, \
local_buf = [u for u in k.uops if u.op is UOps.DEFINE_LOCAL]
self.assertEqual(len(real_local_stores:=[u for u in k.uops if u.op is UOps.STORE and any([lb in u.src for lb in local_buf])]), 3, \
f"should have generated 3 BufferOps.STORE to the local buf but got {len(real_local_stores)}")
self.assertEqual(len(real_local_loads:=[u for u in k.uops if u.uop is UOps.LOAD and any([lb in u.vin for lb in local_buf])]), 3, \
self.assertEqual(len(real_local_loads:=[u for u in k.uops if u.op is UOps.LOAD and any([lb in u.src for lb in local_buf])]), 3, \
f"should have generated 3 BufferOps.LOAD to the local buf but got {len(real_local_loads)}")
self.assertEqual((real_local_stores[1].vin[1].uop, real_local_stores[1].vin[1].arg), (UOps.CONST, 0))
self.assertEqual((real_local_loads[1].vin[1].uop, real_local_loads[1].vin[1].arg), (UOps.CONST, 0))
self.assertEqual((real_local_stores[1].src[1].op, real_local_stores[1].src[1].arg), (UOps.CONST, 0))
self.assertEqual((real_local_loads[1].src[1].op, real_local_loads[1].src[1].arg), (UOps.CONST, 0))
x = Tensor.randn(3,27,32).realize()
helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(axis=2, ddof=0).reshape(-1)])
@@ -261,9 +261,9 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*ast)
k.upcast()
k.linearize()
define_globals = [u for u in k.uops if u.uop is UOps.DEFINE_GLOBAL]
self.assertEqual(len([u for u in k.uops if u.uop is UOps.LOAD and define_globals[1] in u.vin]), 7)
self.assertEqual(len([u for u in k.uops if u.uop is UOps.ALU and u.arg is BinaryOps.ADD]), 25)
define_globals = [u for u in k.uops if u.op is UOps.DEFINE_GLOBAL]
self.assertEqual(len([u for u in k.uops if u.op is UOps.LOAD and define_globals[1] in u.src]), 7)
self.assertEqual(len([u for u in k.uops if u.op is UOps.ALU and u.arg is BinaryOps.ADD]), 25)
opts = [[Opt(op=OptOps.UPCAST, axis=0, amt=2)], [Opt(op=OptOps.UPCAST, axis=0, amt=4)]]
x = Tensor.randn(8,7).softmax().realize()
helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)])
@@ -288,12 +288,12 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*ast)
k.hand_coded_optimizations()
k.linearize()
def get_recursive_children(x:UOp): return set.union(set(x.vin), *[get_recursive_children(v) for v in x.vin])
def get_recursive_children(x:UOp): return set.union(set(x.src), *[get_recursive_children(v) for v in x.src])
loop = None
for u in k.uops:
if u.uop is UOps.RANGE: loop = u
if u.op is UOps.RANGE: loop = u
elif loop is None: continue
elif u.uop is UOps.ENDRANGE and loop in u.vin: loop = None
elif u.op is UOps.ENDRANGE and loop in u.src: loop = None
else: self.assertIn(loop, get_recursive_children(u), f"Any uop within a loop should depend on the loop: {u}")
x = Tensor.randn(3, 27, 32).realize()
helper_linearizer_ast(ast, [x], wanna_output= \
@@ -358,7 +358,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_loads = len([uop for uop in k.uops if uop.uop is UOps.LOAD])
num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD])
assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
@@ -377,7 +377,7 @@ class TestLinearizer(unittest.TestCase):
lin.linearize()
assert len(lin.uops.uops) <= 7, "too many uops"
a_bufs = [u.uop for u in lin.uops.uops[-1].vin[2].vin]
a_bufs = [u.op for u in lin.uops.uops[-1].src[2].src]
assert a_bufs == [UOps.LOAD, UOps.CONST]
def test_upcast_cse(self):
@@ -389,7 +389,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU])
num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
assert num_ops <= 1, "more alu uops than needed"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -401,11 +401,11 @@ class TestLinearizer(unittest.TestCase):
k.upcast()
k.upcast()
k.linearize()
accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
stores = [u for u in k.uops if u.uop is UOps.STORE]
accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
stores = [u for u in k.uops if u.op is UOps.STORE]
assert len(accs) == 0 # it's removed now
assert len(stores) == 1
assert stores[0].vin[-1].dtype == dtypes.float.vec(4)
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -417,15 +417,15 @@ class TestLinearizer(unittest.TestCase):
k.hand_coded_optimizations()
k.linearize()
accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
stores = [u for u in k.uops if u.uop is UOps.STORE]
accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
stores = [u for u in k.uops if u.op is UOps.STORE]
# the first store is to lds and can be upcasted
assert accs[0].dtype == stores[0].vin[-1].dtype == dtypes.float.vec(4)
assert stores[0].vin[0].uop is UOps.DEFINE_LOCAL
assert accs[0].dtype == stores[0].src[-1].dtype == dtypes.float.vec(4)
assert stores[0].src[0].op is UOps.DEFINE_LOCAL
# the second store is to gds with no upcasts
assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float
assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL
assert accs[1].dtype == stores[1].src[-1].dtype == dtypes.float
assert stores[1].src[0].op is UOps.DEFINE_GLOBAL
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
def test_upcast_multireduce_nested_local_upcast(self):
@@ -449,7 +449,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU])
num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
assert num_ops == 0, "more alu uops than needed"
def test_sum_acc_dtype(self):
@@ -458,14 +458,14 @@ class TestLinearizer(unittest.TestCase):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Linearizer(*create_schedule([a.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC]
local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Linearizer(*create_schedule([c.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC]
local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
tests = (
@@ -530,7 +530,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=2)
k.linearize()
assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg = CompiledRunner(k.to_program())
@@ -571,8 +571,8 @@ class TestLinearizer(unittest.TestCase):
r = x.matmul(y, acc_dtype=tc.dtype_out)
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
for u in k.uops:
if u.uop is UOps.WMMA:
assert u.vin[-1].vin[0].uop != UOps.PHI
if u.op is UOps.WMMA:
assert u.src[-1].src[0].op != UOps.PHI
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_casted_phi(self):
@@ -581,9 +581,9 @@ class TestLinearizer(unittest.TestCase):
r = x.matmul(y, acc_dtype=tc.dtype_out)
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
for u in k.uops:
if u.uop is UOps.WMMA:
assert u.vin[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
assert u.vin[-1].vin[0].uop != UOps.PHI
if u.op is UOps.WMMA:
assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
assert u.src[-1].src[0].op != UOps.PHI
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_casted_phi_with_children(self):
@@ -593,9 +593,9 @@ class TestLinearizer(unittest.TestCase):
r = x.matmul(y, acc_dtype=tc.dtype_out).relu()
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
for u in k.uops:
if u.uop is UOps.WMMA:
assert u.vin[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
assert u.vin[-1].vin[0].uop != UOps.PHI
if u.op is UOps.WMMA:
assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
assert u.src[-1].src[0].op != UOps.PHI
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_simple_unroll_no_between_phi_dependencies(self):
@@ -604,11 +604,11 @@ class TestLinearizer(unittest.TestCase):
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1]
# the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x PHI -> ENDRANGE
for u in k.uops:
if u.uop is UOps.PHI:
assert u.vin[1].uop is UOps.ALU
if u.op is UOps.PHI:
assert u.src[1].op is UOps.ALU
# children of PHI are placed after ENDRANGE
if any(x.uop is UOps.PHI for x in u.vin):
end_range = [i for i, x in enumerate(k.uops) if x.uop is UOps.ENDRANGE][0]
if any(x.op is UOps.PHI for x in u.src):
end_range = [i for i, x in enumerate(k.uops) if x.op is UOps.ENDRANGE][0]
assert end_range < k.uops.uops.index(u)
def test_grouped_dims(self):
@@ -649,7 +649,7 @@ class TestLinearizer(unittest.TestCase):
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(*sched[0].ast)
assert not any(u.uop is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
def test_assign_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
@@ -680,10 +680,10 @@ class TestLinearizer(unittest.TestCase):
k.hand_coded_optimizations()
uops = list(k.linearize().uops)
# ignore kernel optimized IF statements for now
if if_op:=next((u for u in uops if u.uop is UOps.IF), None):
if if_op:=next((u for u in uops if u.op is UOps.IF), None):
uops = uops[:uops.index(if_op)]
assert len(set([u.uop for u in uops if u.uop in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both"
assert len([u for u in uops if u.uop is UOps.PHI]) == 0, "PHI should have been simplified"
assert len(set([u.op for u in uops if u.op in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both"
assert len([u for u in uops if u.op is UOps.PHI]) == 0, "PHI should have been simplified"
# TODO: once uops track min/max this will be fixed
#assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
@@ -708,17 +708,17 @@ class TestLinearizer(unittest.TestCase):
out = x.matmul(y)
k = helper_linearizer_opt(out)[-1]
# check that the float4 cast collapses
store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE]
store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE]
for val in store_vals:
assert val.dtype == dtypes.float.vec(4) and val.uop is not UOps.CAST
assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.CAST
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_values(self):
x = Tensor.randn((4,3,6,6)).realize()
out = x.flip((0,1)).contiguous()
k = helper_linearizer_opt(out)[-1]
store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop is not UOps.CAST
store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0]
assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.CAST
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -729,16 +729,16 @@ class TestLinearizer(unittest.TestCase):
opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
k = helper_linearizer_opt(out, opts=[opt])[-1]
def get_recursive(uop): return set.union(set(uop.vin), [uop], *[get_recursive(v) for v in uop.vin])
local_stores = [u for u in k.uops if u.uop is UOps.STORE and any(x.uop is UOps.DEFINE_LOCAL for x in get_recursive(u.vin[0]))]
global_stores = [u for u in k.uops if u.uop is UOps.STORE and any(x.uop is UOps.DEFINE_GLOBAL for x in get_recursive(u.vin[0]))]
barrier = [u for u in k.uops if u.uop is UOps.BARRIER][0]
def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
local_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
global_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop is not UOps.CAST
assert store.src[-1].dtype == dtypes.float.vec(2) and store.src[-1].op is not UOps.CAST
# check the children's vins
assert barrier.vin == tuple(local_stores)
assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1
assert barrier.src == tuple(local_stores)
assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -747,14 +747,14 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = helper_linearizer_opt(r)[-1]
stores = [u for u in k.uops if u.uop is UOps.STORE]
stores = [u for u in k.uops if u.op is UOps.STORE]
# the float4 value stores directly in lds and we skip upcast
assert stores[0].vin[-1].dtype == dtypes.float.vec(4)
assert stores[0].vin[-1].uop is not UOps.CAST
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
assert stores[0].src[-1].op is not UOps.CAST
# the global store doesn't change
assert stores[1].vin[-1].dtype == dtypes.float
assert stores[1].src[-1].dtype == dtypes.float
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts(self):
@@ -764,8 +764,8 @@ class TestLinearizer(unittest.TestCase):
Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
]
k = helper_linearizer_ast(ast, [Tensor.empty(240*40).realize()], opts=[opt])[-1]
out = [u for u in k.uops if u.uop is UOps.STORE][0]
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4)
out = [u for u in k.uops if u.op is UOps.STORE][0]
assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -775,15 +775,15 @@ class TestLinearizer(unittest.TestCase):
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
k = helper_linearizer_ast(ast, [Tensor.empty(8*32).realize()], opts=[opt])[-1]
out = [u for u in k.uops if u.uop is UOps.STORE][0]
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(2)
out = [u for u in k.uops if u.op is UOps.STORE][0]
assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(2)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
@staticmethod
def count_float4(k):
return (len([uop for uop in k.uops if uop.uop is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
len([uop for uop in k.uops if uop.uop is UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)]))
return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(4)]))
# TODO: express opts below as auto opts
@@ -1178,10 +1178,10 @@ class TestKernelOpts(unittest.TestCase):
for k in lins:
seen_bar = False
for u in k.uops:
if u.uop is UOps.BARRIER:
if u.op is UOps.BARRIER:
assert not seen_bar, "redudant barrier"
seen_bar = True
elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False
elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False
@unittest.skip("TODO: broken")
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
@@ -1200,10 +1200,10 @@ class TestKernelOpts(unittest.TestCase):
for k in lins:
seen_bar = False
for u in k.uops:
if u.uop is UOps.BARRIER:
if u.op is UOps.BARRIER:
assert not seen_bar, "redudant barrier"
seen_bar = True
elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False
elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -1225,10 +1225,10 @@ class TestKernelOpts(unittest.TestCase):
for k in lins:
seen_bar = False
for u in k.uops:
if u.uop is UOps.BARRIER:
if u.op is UOps.BARRIER:
assert not seen_bar, "redudant barrier"
seen_bar = True
elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False
elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False
def test_upcasts(self):
N = 16

View File

@@ -6,7 +6,7 @@ from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat
class TestPatternMatcher(unittest.TestCase):
def assert_equiv_uops(self, uop1:UOp, uop2:UOp):
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
self.assertEqual(uop1.uop, uop2.uop)
self.assertEqual(uop1.op, uop2.op)
self.assertEqual(uop1.dtype, uop2.dtype)
self.assertEqual(uop1.arg, uop2.arg)
@@ -64,7 +64,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c3), c3)
def test_dup_name(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
@@ -91,13 +91,13 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_vin_one(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c2), None)
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), None)
@@ -105,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c5), None)
def test_vin_permutations(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -118,7 +118,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c6), None)
def test_vin_repeat(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=UPat(UOps.CONST)), lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -127,7 +127,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)
def test_allow_len(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)])
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
@@ -157,7 +157,7 @@ class TestPatternMatcher(unittest.TestCase):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int),
lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
matcher.rewrite_graph(uops)
uops.remove_childless(set(x for x in uops if x.uop in {UOps.STORE}))
uops.remove_childless(set(x for x in uops if x.op in {UOps.STORE}))
self.assertEqual(len(uops.uops), 3)

View File

@@ -12,7 +12,7 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 3.0)
def test_where_same_fold(self):
@@ -24,7 +24,7 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 1.0)
def test_where_const_fold(self):
@@ -35,7 +35,7 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 2.0)
def test_const_cast(self):
@@ -44,7 +44,7 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.CONST)
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 0)
def test_cast_vectorized_fold(self):
@@ -56,7 +56,7 @@ class TestUOpGraph(unittest.TestCase):
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu))
g = UOpGraph([out])
self.assertEqual(len([x for x in g.uops if x.uop is UOps.CAST]), 0)
self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)
def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
@@ -67,10 +67,10 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 3)
out = g.uops[-1]
self.assertEqual(out.uop, UOps.ALU)
self.assertEqual(out.op, UOps.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)
self.assertEqual(out.vin[1].uop, UOps.CONST)
self.assertEqual(out.vin[1].arg, 6)
self.assertEqual(out.src[1].op, UOps.CONST)
self.assertEqual(out.src[1].arg, 6)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -226,7 +226,7 @@ class TestConstantFolding(unittest.TestCase):
si = create_schedule([t.lazydata])
assert len(si) == 1
ji = lower_schedule_item(si[-1])
assert any(uop.uop is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.uop for uop in ji.prg.p.uops]} does not contain bitcast"
assert any(uop.op is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
class TestGatedStoreRewrite(unittest.TestCase):
@unittest.skip("not yet implemented")
@@ -240,9 +240,9 @@ class TestGatedStoreRewrite(unittest.TestCase):
gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
uops = UOpGraph([UOp(UOps.STORE, None, (gmem, idx, value, gate))])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.uop is UOps.IF)
endif = next(u for u in uops if u.uop is UOps.ENDIF)
assert endif.vin[0] is if_uop
if_uop = next(u for u in uops if u.op is UOps.IF)
endif = next(u for u in uops if u.op is UOps.ENDIF)
assert endif.src[0] is if_uop
nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
assert nested_uops == (gmem, gidx0, idx, value)
@@ -261,9 +261,9 @@ class TestGatedStoreRewrite(unittest.TestCase):
outs.append(UOp(UOps.STORE, None, (gmem1, idx, value1)))
uops = UOpGraph(outs)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.uop is UOps.IF)
endif = next(u for u in uops if u.uop is UOps.ENDIF)
assert endif.vin[0] is if_uop
if_uop = next(u for u in uops if u.op is UOps.IF)
endif = next(u for u in uops if u.op is UOps.ENDIF)
assert endif.src[0] is if_uop
nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
assert nested_uops == (gmem0, value0)

View File

@@ -29,23 +29,23 @@ class UOps(Enum):
def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
@dataclass(eq=False)
class UOp:
uop: UOps
op: UOps
dtype: Optional[DType] = None
vin: Tuple[UOp, ...] = tuple()
src: Tuple[UOp, ...] = tuple()
arg: Any = None
def tuple(self): return (self.uop, self.dtype, self.vin, self.arg)
def tuple(self): return (self.op, self.dtype, self.src, self.arg)
def commutative(self) -> bool:
return self.uop is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
return self.op is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
@functools.cached_property
def cmp_tuple(self):
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
return (self.uop.value, (self.arg if self.uop is not UOps.DEFINE_VAR else self.arg.expr) if self.uop is not UOps.ALU else \
(type(self.uop), self.uop.value), self.dtype, self.vin)
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
(type(self.op), self.op.value), self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
def __repr__(self):
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
def name(self, name:Optional[str]): return UOp(UOps.VAR, vin=(self,), arg=name)
def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)
def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)
@@ -75,35 +75,35 @@ class UOp:
@staticmethod
def cvar(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
@functools.cached_property
def parents(self) -> Set[UOp]: return set.union(set(self.vin), *[x.parents for x in self.vin])
def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src])
@property # parents with self
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.uop is UOps.DEFINE_VAR])
def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
def uop_alu_resolve(u:UOp) -> sint:
if u.uop is UOps.CONST: return u.arg
if u.uop is UOps.DEFINE_VAR: return u.arg
if u.uop is UOps.SPECIAL: return u.arg[2]-1
if u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
if u.uop is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.vin[0]) * (2**cast(int, uop_alu_resolve(u.vin[1])))
if u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
raise RuntimeError(f"ALU resolve fail @ {u.uop}")
if u.op is UOps.CONST: return u.arg
if u.op is UOps.DEFINE_VAR: return u.arg
if u.op is UOps.SPECIAL: return u.arg[2]-1
if u.op is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.src[0]) * uop_alu_resolve(u.src[1])
if u.op is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.src[0]) * (2**cast(int, uop_alu_resolve(u.src[1])))
if u.op is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.src[0]) + uop_alu_resolve(u.src[1])
raise RuntimeError(f"ALU resolve fail @ {u.op}")
# *** simplification logic ***
@dataclass(frozen=True)
class UPat:
uop: Optional[Union[UOps, Set[UOps]]] = None
op: Optional[Union[UOps, Set[UOps]]] = None
arg: Any = None
vin: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
name: Optional[str] = None
dtype: Optional[Union[DType, Set[DType]]] = None
allow_len: Set[int] = field(default_factory=set)
@staticmethod
def compile(u: UOp, name:Optional[str]=None) -> UPat:
if u.uop is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.vin) == 0 else UPat.compile(u.vin[0], name or u.arg)
return UPat(u.uop, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.vin]) if u.vin != () else None, name, u.dtype)
if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.src]) if u.src != () else None, name, u.dtype)
T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
@@ -117,15 +117,15 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
if pat.name is not None: store[pat.name] = uop
if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
if pat.uop is not None and __unmatch(pat.uop, uop.uop): return False
if pat.vin is None: return True
if pat.op is not None and __unmatch(pat.op, uop.op): return False
if pat.src is None: return True
# only one if it's a tuple
# try all permutations if it's a list
# repeat if it's a UPat
for vp in itertools.permutations(pat.vin) if isinstance(pat.vin,list) else ([pat.vin] if isinstance(pat.vin,tuple) else [(pat.vin,)*len(uop.vin)]):
if len(uop.vin) != len(vp) and (len(uop.vin) not in pat.allow_len): return False
for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
new_store = store.copy()
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
store.update(new_store)
return True
return False
@@ -137,14 +137,14 @@ class PatternMatcher:
# uop is required, arg is optional
for p,fxn in self.patterns:
if isinstance(p, UOp): p = UPat.compile(p)
assert p.uop is not None
if isinstance(p.uop, set):
for uop in p.uop: self.pdict[(uop, p.arg)].append((p, fxn))
assert p.op is not None
if isinstance(p.op, set):
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
else:
self.pdict[(p.uop, p.arg)].append((p, fxn))
self.pdict[(p.op, p.arg)].append((p, fxn))
def rewrite(self, uop:UOp) -> Optional[UOp]:
for p,fxn in itertools.chain(self.pdict[(uop.uop, uop.arg)], self.pdict[(uop.uop, None)]):
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
store: Dict[str, UOp] = {}
if _match(uop, p, store): return fxn(**store)
return None
@@ -152,7 +152,7 @@ class PatternMatcher:
def sum_collapse(phi_input, loop, val1, val2):
for v1,v2 in [(val1, val2), (val2, val1)]:
if loop not in v1.parents:
loop_range = loop.vin[1]-loop.vin[0]
loop_range = loop.src[1]-loop.src[0]
ret = v1*loop_range.cast(v1.dtype)
return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
return None
@@ -168,34 +168,34 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
# this is symbolic 2.0
constant_folder = PatternMatcher([
# arange loop folding (early)
(UPat(UOps.ALU, TernaryOps.WHERE, vin=(UPat(UOps.ALU, BinaryOps.CMPLT, vin=(
UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL,
vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
(UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL,
src=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
# sum collapse to mul (with possible GEP)
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="phi_input", vin=(UPat(UOps.RANGE, name="loop"),)),
UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
(UPat(UOps.PHI, vin=(UPat(UOps.GEP, name="phi_input",
vin=(UPat(UOps.DEFINE_ACC, vin=(UPat(UOps.RANGE, name="loop"),)),)),
UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=(UPat(UOps.RANGE, name="loop"),)),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input",
src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.RANGE, name="loop"),)),)),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
# deal with UNMUL
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"),
UPat(UOps.UNMUL, vin=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
lambda c1,c2,v: v if c1.arg == c2.arg else None),
(UOp(UOps.UNMUL, vin=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
(UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
(UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
(UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
# max on special can go away (TODO: special should be variable, same thing applies)
(UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
# const rules
(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", vin=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])),
(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, vin=tuple()), UPat(name="x"))), lambda x: x),
(UPat(UOps.PHI, vin=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=tuple()), UPat(name="x"))), lambda x: x),
(UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
(UPat(UOps.DEFINE_ACC, name="root", vin=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])),
(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
(UPat(UOps.DEFINE_ACC, name="root", src=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])),
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
# max -2147483648
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
# -(-x) -> x
@@ -209,7 +209,7 @@ constant_folder = PatternMatcher([
(UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
(UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
(UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
# ** self folding **
(UOp.var('x') + 0, lambda x: x), # x+0 -> x
(UOp.var('x') - 0, lambda x: x), # x-0 -> x
@@ -222,8 +222,8 @@ constant_folder = PatternMatcher([
(UOp.var('x') * 0, lambda x: x if isinstance(x.arg, float) and math.isnan(x.arg) else UOp.const(x.dtype, 0)),
(UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
# ** load/store folding **
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"),
UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)),
(UPat(UOps.STORE, src=(UPat(name="buf"), UPat(name="idx"),
UPat(UOps.LOAD, src=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)),
# ** two stage add/sub folding **
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
@@ -247,17 +247,17 @@ constant_folder = PatternMatcher([
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# store float4/float2 directly (remove CAST/GEP)
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(4)))), UOp.store),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(2)))), UOp.store),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
# CAST-PHI-GEP -> PHI-CAST
(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
(UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
(UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
# NEG/CMPLT -> CMPLT
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
])
# *** uop graph ***
@@ -271,8 +271,8 @@ class UOpGraph:
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
def __getitem__(self, index) -> UOp: return self.uops[index]
def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR], key=lambda v: v.expr)
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]
def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]
@property
def uops(self):
@@ -285,7 +285,7 @@ class UOpGraph:
def print(self):
for i,u in enumerate(self):
print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
def graph_rewrite(self, sink, pm):
# recursive rewrite
@@ -305,7 +305,7 @@ class UOpGraph:
recurse_cnt += 1
changed += recurse_cnt
# NOTE: this changes UOp, so we have to delete caches
up.vin = tuple(rewrite(x) for x in up.vin)
up.src = tuple(rewrite(x) for x in up.src)
if 'parents' in up.__dict__: delattr(up, 'parents')
if 'cmp_tuple' in up.__dict__: delattr(up, 'cmp_tuple')
# replace with cached nodes
@@ -327,16 +327,16 @@ class UOpGraph:
n = unprocessed_nodes.pop(0)
if n in all_nodes: continue
all_nodes[n] = None
for x in n.vin:
for x in n.src:
early_in_degree[n] += 1
children[x].append(n)
unprocessed_nodes += list(n.vin)
unprocessed_nodes += list(n.src)
early_queue = [x for x in all_nodes if early_in_degree[x] == 0]
replace_nodes: Dict[UOp, UOp] = {}
while len(early_queue):
n = early_queue.pop(0)
if n in replace_nodes: continue
key = (n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
if found:=self.nodes.get(key): replace_nodes[n] = found
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
for x in children[n]:
@@ -367,29 +367,29 @@ class UOpGraph:
def add_parents(u:UOp):
if u in nodes: return
nodes[u] = None
for x in u.vin:
for x in u.src:
add_parents(x)
in_degree[u] += 1
graph[x].append(u)
if u.uop is UOps.RANGE: loops.append(u)
if u.uop is UOps.IF: ifs.append(u)
sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
if u.op is UOps.RANGE: loops.append(u)
if u.op is UOps.IF: ifs.append(u)
sink = UOp(UOps.SINK, None, tuple(x for x in sink.src if x.op is not UOps.NOOP))
add_parents(sink)
@functools.lru_cache(None)
def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
if x.uop is UOps.SINK: return set()
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in graph[x] if x.uop is not end]))
if x.op is UOps.SINK: return set()
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in graph[x] if x.op is not end]))
# scope children impact the toposort and END* insertion
end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
scope_children = {p:get_recursive_children(p, end_for_uop[p.uop][0]) for p in (loops+ifs)[::-1]}
scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}
queue: List = []
def push(u):
priority = 0
# prefer uops that are loop children
for l, ss in scope_children.items():
if l.uop is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
heapq.heappush(queue, (priority, u))
for u in nodes:
@@ -403,20 +403,20 @@ class UOpGraph:
while queue:
p,x = heapq.heappop(queue)
if DEBUG >= 7: print(p,x)
if x.uop is UOps.DEFINE_ACC and len(x.vin):
idx = min([self._uops.index(l) for l in x.vin])
if x.op is UOps.DEFINE_ACC and len(x.src):
idx = min([self._uops.index(l) for l in x.src])
self._uops.insert(idx, x)
else:
self._uops.append(x)
for u, ss in scope_children.items():
if x in ss:
ss.remove(x)
if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.uop][1], None, (u,)))
if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
for u in graph[x]:
in_degree[u] -= 1
if in_degree[u] == 0: push(u)
assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
self._uops = self._uops[:-1]
if type_verify: self.type_verify()
@@ -431,34 +431,34 @@ class UOpGraph:
dont_count: Set[UOp] = set()
if ignore_indexing:
for u in self.uops:
if u.uop is UOps.LOAD:
dont_count = dont_count.union(u.vin[1].sparents)
if len(u.vin) > 3: dont_count = dont_count.union(u.vin[2].sparents)
elif u.uop is UOps.STORE:
dont_count = dont_count.union(u.vin[1].sparents)
if len(u.vin) > 3: dont_count = dont_count.union(u.vin[3].sparents)
if u.op is UOps.LOAD:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
elif u.op is UOps.STORE:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
for u in self.uops:
if u.uop is UOps.RANGE:
if u.op is UOps.RANGE:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.vin[1])
elif u.uop is UOps.ENDRANGE:
mults *= uop_alu_resolve(u.src[1])
elif u.op is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.uop is UOps.LOAD:
elif u.op is UOps.LOAD:
assert u.dtype is not None
mem += u.dtype.itemsize * mults
elif u.uop is UOps.STORE:
assert u.vin[2].dtype is not None
mem += u.vin[2].dtype.itemsize * mults
elif u.uop is UOps.ALU and u not in dont_count:
elif u.op is UOps.STORE:
assert u.src[2].dtype is not None
mem += u.src[2].dtype.itemsize * mults
elif u.op is UOps.ALU and u not in dont_count:
flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
elif u.uop is UOps.WMMA and u not in dont_count:
elif u.op is UOps.WMMA and u not in dont_count:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults
return flops, mem
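A rough, self-contained analogue of the flops/mem walk above: each op's cost is scaled by the product of the extents of the enclosing RANGE loops. The (op, arg) tuple encoding is an assumption, not tinygrad's representation:

def flops_mem(uops):
  flops, mem, mults, stack = 0, 0, 1, []
  for op, arg in uops:
    if op == "RANGE": stack.append(mults); mults *= arg   # arg: loop extent
    elif op == "ENDRANGE": mults = stack.pop()
    elif op in {"LOAD", "STORE"}: mem += arg * mults       # arg: itemsize in bytes
    elif op == "ALU": flops += mults
  return flops, mem

# a hypothetical 16-iteration loop doing one float32 load, one ALU op and one float32 store
print(flops_mem([("RANGE", 16), ("LOAD", 4), ("ALU", None), ("STORE", 4), ("ENDRANGE", None)]))
# -> (16, 128): 16 flops and 16*(4+4) bytes moved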
def type_verify(self):
for u in self.uops:
uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
uop, arg, vin, dtype = u.op, u.arg, u.src, u.dtype
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
if uop is UOps.DEFINE_ACC: arg = arg[0]
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
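A loose analogue of the constant type check above, with string dtypes standing in for tinygrad's dtypes; check_const is illustrative only:

def check_const(arg, dtype:str):
  # the Python type of a constant's arg must match what the dtype coerces it to
  coerced = float(arg) if dtype.startswith("float") else bool(arg) if dtype == "bool" else int(arg)
  assert type(arg) is type(coerced), f"type of {arg=} does not match {dtype}"

check_const(1.0, "float32")  # ok
check_const(1, "int32")      # ok
# check_const(1, "float32")  # would raise: int arg for a float dtype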

@@ -77,13 +77,12 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
def _tree(luop:Union[LazyOp,UOp], cycles, cnt, prefix=""):
cnt[0] += 1
if len(src:=luop.vin if hasattr(luop,'vin')else luop.src) == 0:
return [f"━━ {prefix}{(luop.op if hasattr(luop, 'op') else luop.uop).name} {luop.arg if luop.arg else ''}"]
if len(luop.src) == 0: return [f"━━ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
if (lid := id(luop)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
return [f"━⬆︎ goto {cycles[id(luop)][0]}: {(luop.op if hasattr(luop,'op')else luop.uop).name}"]
return [f"━⬆︎ goto {cycles[id(luop)][0]}: {luop.op.name}"]
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
lines = [f"━┳ {prefix}{(luop.op if hasattr(luop,'op')else luop.uop).name} {luop.arg if luop.arg else ''}"]
childs = [_tree(c, cycles, cnt) for c in src[:]]
lines = [f"━┳ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
childs = [_tree(c, cycles, cnt) for c in luop.src[:]]
for c in childs[:-1]: lines += [f"{c[0]}"] + [f"{l}" for l in c[1:]]
return lines + [""+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
@@ -95,7 +94,7 @@ def graph_uops(uops:List[UOp]):
UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
G = nx.DiGraph()
for u in uops:
if u.uop in {UOps.ENDRANGE, UOps.ENDIF}: continue
G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
for v in u.src: G.add_edge(uops.index(v), uops.index(u))
save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
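A runnable variant of graph_uops above (assuming networkx is installed), with a namedtuple standing in for the renamed UOp fields instead of the real class:

import networkx as nx
from collections import namedtuple

Node = namedtuple("Node", ["op", "dtype", "src", "arg"])  # stand-in for UOp after the rename

def graph_nodes(nodes):
  G = nx.DiGraph()
  for i, n in enumerate(nodes):
    G.add_node(i, label=f"{n.op}{' '+str(n.arg) if n.arg is not None else ''}\n{n.dtype}", style="filled")
  for i, n in enumerate(nodes):
    for s in n.src: G.add_edge(nodes.index(s), i)  # edge from each source to its consumer
  return G

c = Node("CONST", "float32", (), 1.0)
alu = Node("ALU", "float32", (c,), "ADD")
print(list(graph_nodes([c, alu]).edges()))  # -> [(0, 1)]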

@@ -131,22 +131,22 @@ class PTXRenderer(Renderer):
return ret
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.IF:
assert vin[0].dtype is not None
kk(*self.render_bra(f"IF_{r[vin[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDRANGE:
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred))
elif uop is UOps.ENDIF:
kk(f"IF_{r[vin[0].vin[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
kk(f"IF_{r[vin[0].src[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
assert vin[0].dtype == dtypes.int64, "store isn't int64"
assert vin[1].uop is UOps.CONST, f"store isn't const {u}"
mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
assert vin[1].op is UOps.CONST, f"store isn't const {u}"
mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
if vin[2].dtype.count > 1:
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_types[vin[2].dtype.scalar()]} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
@@ -178,8 +178,8 @@ class PTXRenderer(Renderer):
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
elif uop is UOps.LOAD:
assert vin[0].dtype == dtypes.int64, "load isn't int64"
assert vin[1].uop is UOps.CONST, f"load isn't const {u}"
mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
assert vin[1].op is UOps.CONST, f"load isn't const {u}"
mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(vin)>3):
@@ -227,44 +227,44 @@ class PTXRenderer(Renderer):
ptx_matcher = PatternMatcher([
(UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHL)),
(UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)),
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
(UPat(UOps.ALU, BinaryOps.ADD,
[UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(),UPat())),
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
# ptr_ar (load/store)
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
lambda root, alu, const: UOp(root.uop, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
lambda root, alu, const: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.CONST, name="const"))),
lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
UOp.const(dtypes.int64, const.arg * root.vin[0].dtype.itemsize),
)+root.vin[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
)+root.src[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(name="alu"))), # no const here
lambda root, alu: UOp(root.uop, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
UOp.const(dtypes.int64, 0))+root.vin[2:])),
lambda root, alu: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, 0))+root.src[2:])),
])
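The first ptx_matcher rules above turn an integer multiply (or divide) by a power-of-two constant into a shift. A self-contained sketch of the multiply case, again using a namedtuple stand-in rather than UPat/PatternMatcher:

import math
from collections import namedtuple

Node = namedtuple("Node", ["op", "dtype", "src", "arg"])  # stand-in for UOp after the rename

def rewrite_mul_pow2(n:Node) -> Node:
  if n.op == "ALU" and n.arg == "MUL" and n.dtype.startswith("int"):
    consts = [s for s in n.src if s.op == "CONST" and s.arg > 0 and (s.arg & (s.arg - 1)) == 0]
    others = [s for s in n.src if s not in consts]
    if len(consts) == 1 and len(others) == 1:
      shift = Node("CONST", n.dtype, (), int(math.log2(consts[0].arg)))
      return Node("ALU", n.dtype, (others[0], shift), "SHL")
  return n

x = Node("DEFINE_VAR", "int32", (), "x")
mul = Node("ALU", "int32", (Node("CONST", "int32", (), 8), x), "MUL")
print(rewrite_mul_pow2(mul).arg)  # -> SHL, with a shift amount of 3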

@@ -104,10 +104,10 @@ class CStyleLanguage(Renderer):
c[prefix] += 1
return ret
child_count = Counter(v for ru in uops for v in ru.vin)
child_count = Counter(v for ru in uops for v in ru.src)
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
# these four uops don't have output dtypes
if uop is UOps.IF:
kk(f"if ({r[vin[0]]}) {{")
@@ -118,7 +118,7 @@ class CStyleLanguage(Renderer):
kk("}")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
@@ -138,7 +138,7 @@ class CStyleLanguage(Renderer):
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
@@ -169,7 +169,7 @@ class CStyleLanguage(Renderer):
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert vin[0].dtype is not None
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
from_ssa = vin[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
else: raise RuntimeError(f"failed to render {uop}")
@@ -235,7 +235,7 @@ class MetalRenderer(CStyleLanguage):
return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA])
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
@@ -285,7 +285,7 @@ class CUDARenderer(CStyleLanguage):
prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
# TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
for arg in dedup([uop.arg for uop in uops if uop.uop is UOps.WMMA]):
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
@@ -370,7 +370,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
prefix += [_make_hip_dtype(*x) for x in vec_dts]
for arg in dedup([uop.arg for uop in uops if uop.uop is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
@@ -379,7 +379,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def get_kernel_modifier(self, uops:UOpGraph) -> str:
requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop is UOps.SPECIAL and u.arg[1][0] == "l")
requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

@@ -78,7 +78,7 @@ class LLVMRenderer(Renderer):
module = ir.Module(name=__file__)
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
@@ -101,7 +101,7 @@ class LLVMRenderer(Renderer):
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.STORE:
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
if len(vin) > 3:
@@ -115,7 +115,7 @@ class LLVMRenderer(Renderer):
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].src[1]]), loop_entry_bb, bb[-1].block)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE:
@@ -147,7 +147,7 @@ class LLVMRenderer(Renderer):
lvars[u] = lvars[vin[1]]
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
backward = vin[0]
while backward.uop is UOps.PHI: backward = backward.vin[0]
while backward.op is UOps.PHI: backward = backward.src[0]
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else vin[0].dtype)
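A minimal sketch of the PHI backtrace above, using dicts with op/src keys in place of UOps:

def backtrace_acc(phi):
  # follow src[0] through any chain of PHIs back to the originating DEFINE_ACC
  backward = phi["src"][0]
  while backward["op"] == "PHI": backward = backward["src"][0]
  return backward

acc = {"op": "DEFINE_ACC", "src": ()}
outer = {"op": "PHI", "src": ({"op": "PHI", "src": (acc,)},)}
print(backtrace_acc(outer)["op"])  # -> DEFINE_ACC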

@@ -191,7 +191,7 @@ class PythonRenderer(Renderer):
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
def render(self, name:str, uops:UOpGraph) -> str:
lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
lops = [(u.op, u.dtype, [uops.uops.index(v) for v in u.src], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()
class PythonCompiler(Compiler):
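A self-contained sketch of PythonRenderer.render above: each uop is flattened to (op, dtype, source indices, arg) and the whole list is pickled and base64-encoded; the Node namedtuple stands in for the real UOp:

import base64, pickle
from collections import namedtuple

Node = namedtuple("Node", ["op", "dtype", "src", "arg"])  # stand-in for UOp after the rename

def render(uops):
  lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
  return base64.b64encode(pickle.dumps(lops)).decode()

a = Node("CONST", "float32", (), 1.0)
b = Node("ALU", "float32", (a,), "EXP2")
print(pickle.loads(base64.b64decode(render([a, b])))[1])  # -> ('ALU', 'float32', [0], 'EXP2')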