Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
rename uop [run_process_replay] (#5031)
* rename
* fix unittests
* rename vin
* fix test
* fix type [run_process_replay]
* rm pre commit hook change
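The commit is a mechanical field rename on UOp and UPat: the attribute uop becomes op and vin becomes src (and the matching vin= keyword on UPat becomes src=); graph structure and semantics are unchanged. A minimal before/after sketch of what the rename means for callers — illustrative only; the UOps/UOp/UPat import matches the one used in the tests below, while the import locations of dtypes and BinaryOps are assumptions:

    from tinygrad.codegen.uops import UOps, UOp, UPat
    from tinygrad.dtype import dtypes   # assumed location of dtypes
    from tinygrad.ops import BinaryOps  # assumed location of BinaryOps

    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
    add = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)

    # before this commit: add.uop is UOps.ALU and add.vin == (c1, c2)
    # after this commit:  add.op  is UOps.ALU and add.src == (c1, c2)

    # UPat keyword arguments follow the same rename: vin= becomes src=
    pat = UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST)))

The diff below applies this rename across the external fuzzers, the tests, and the uops module (tinygrad.codegen.uops). Removed lines are marked with -, added lines with +.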
@@ -48,4 +48,4 @@ repos:
entry: env PYTHONPATH="." python3 -m pylint tinygrad/
language: system
always_run: true
-pass_filenames: false
+pass_filenames: false
test/external/fuzz_linearizer.py (vendored, 2 changed lines)
@@ -16,7 +16,7 @@ from tinygrad.ops import LazyOp, UnaryOps, BufferOps
from test.helpers import is_dtype_supported

def tuplize_uops(uops:List[UOp]) -> Tuple:
-return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
+return tuple([(x.op, x.dtype, tuple(uops.index(x) for x in x.src), x.arg) for x in uops])

device = Device[Device.DEFAULT]
test/external/fuzz_uops.py (vendored, 10 changed lines)
@@ -12,11 +12,11 @@ def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int]
paths: List[List[UOp]] = []
# TODO: express DEFINE_ACC and loop children conditions in the graph, builtin.
for p in find_all_toposorts(graph, in_degree):
-assert p[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {p[-1]}"
+assert p[-1].op is UOps.SINK, f"didn't end with SINK, ended with {p[-1]}"
paths.append(path:=list(p[:-1]))
for u in path:
-if u.uop is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,)))
-if u.uop is UOps.RANGE:
+if u.op is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,)))
+if u.op is UOps.RANGE:
path.insert(max(path.index(x) for x in loops_children[u] if x in path)+1, UOp(UOps.ENDRANGE, None, (u,)))
return paths

@@ -58,9 +58,9 @@ def find_all_toposorts(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[
def recurse_paths(path:List[UOp]):
for v, d in in_degree.items():
if d != 0 or v in visited: continue
-if v.uop is UOps.DEFINE_ACC and any(l not in path for l in v.vin): continue
+if v.op is UOps.DEFINE_ACC and any(l not in path for l in v.src): continue
for u in graph[v]: in_degree[u] -= 1
-if v.uop is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.vin), v)
+if v.op is UOps.DEFINE_ACC: path.insert(min(path.index(l) for l in v.src), v)
else: path.append(v)
visited.add(v)
recurse_paths(path)
@@ -38,7 +38,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
-assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
+assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
np_c = np_a @ np_b
if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
@@ -54,7 +54,7 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
-wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
+wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA])
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
if ensure_triggered:
assert wmmas > 0, "tensor core not triggered"
@@ -94,8 +94,8 @@ class TestLinearizer(unittest.TestCase):
b_t = Tensor.full(st.shape, 3).contiguous().realize()
lin = helper_linearizer_ast((out0, out1), [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0]

-stores = [u for u in lin.uops if u.uop is UOps.STORE]
-mutable_bufs = [u for u in lin.uops if u.uop is UOps.DEFINE_GLOBAL and u.arg[-1]]
+stores = [u for u in lin.uops if u.op is UOps.STORE]
+mutable_bufs = [u for u in lin.uops if u.op is UOps.DEFINE_GLOBAL and u.arg[-1]]
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]
@@ -108,92 +108,92 @@ class TestLinearizer(unittest.TestCase):

load_t = Tensor.full(load.st.shape, 1).contiguous().realize()
k = helper_linearizer_ast(ast, [load_t], wanna_output=[load_t.numpy().sum()])[1]
-self.assertEqual(k.uops[-1].uop, UOps.ENDIF)
-self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.uop is UOps.STORE][-1]), k.uops.uops.index(k.uops[-1]))
+self.assertEqual(k.uops[-1].op, UOps.ENDIF)
+self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.op is UOps.STORE][-1]), k.uops.uops.index(k.uops[-1]))

def test_two_nested_range(self):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).sum()
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0]
-ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
+ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
if getenv("PTX"):
# RANGE -> 2xLOAD_INDEXING -> LOAD -> RANGE -> PHI
assert ranges[1] == ranges[0]+4
-assert lin.uops[ranges[0]+3].uop is UOps.LOAD
+assert lin.uops[ranges[0]+3].op is UOps.LOAD
else:
# RANGE -> LOAD -> RANGE -> PHI
assert ranges[1] == ranges[0]+2
-assert lin.uops[ranges[0]+1].uop is UOps.LOAD
+assert lin.uops[ranges[0]+1].op is UOps.LOAD

def test_three_nested_range(self):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0]
-ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
+ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
if getenv("PTX"):
# RANGE -> RANGE -> 2xLOAD_INDEXING -> LOAD -> RANGE -> PHI
assert ranges[2] == ranges[1]+4 == ranges[0]+5
-assert lin.uops[ranges[1]+3].uop is UOps.LOAD
+assert lin.uops[ranges[1]+3].op is UOps.LOAD
else:
# RANGE -> RANGE -> LOAD -> RANGE -> PHI
assert ranges[2] == ranges[1]+2 == ranges[0]+3
-assert lin.uops[ranges[1]+1].uop is UOps.LOAD
+assert lin.uops[ranges[1]+1].op is UOps.LOAD

def test_two_nested_range_alt_indexing(self):
a = Tensor([2, 2]).realize()
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum()
lin = helper_linearizer_opt(out, wanna_output=[24])[0]
-ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
+ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
if getenv("PTX"):
# RANGE -> CAST ridx -> LOAD_INDEXING -> 4x ALU -> RANGE -> LOAD -> RANGE -> PHI
assert ranges[1] == ranges[0]+6
-assert lin.uops[ranges[1]+11].uop is UOps.ENDRANGE
+assert lin.uops[ranges[1]+11].op is UOps.ENDRANGE
else:
# RANGE -> 4x ALU -> RANGE -> 9x ALU + 1x LOAD -> PHI
assert ranges[1] == ranges[0]+5
-assert lin.uops[ranges[1]+11].uop is UOps.ENDRANGE
+assert lin.uops[ranges[1]+11].op is UOps.ENDRANGE

def test_range_outer_op_before_phi(self):
a = Tensor.randn(4, 1).realize()
b = Tensor.randn(1, 1).realize()
out = (a + b[0]).sum() + b[0]
lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()[0]])[0]
-ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
+ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
# LOAD -> RANGE -> LOAD -> PHI
-assert lin.uops[ranges[0]-2].uop is UOps.LOAD
+assert lin.uops[ranges[0]-2].op is UOps.LOAD

def test_range_outer_op_before_phi_nested_range(self):
a = Tensor.randn(2, ).realize()
b = Tensor.randn(1, 1).realize()
out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()[0]])[0]
-ranges = [i for i,u in enumerate(lin.uops) if u.uop is UOps.RANGE]
+ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
if getenv("PTX"):
# LOAD -> RANGE -> 3xLOAD_INDEXING -> LOAD -> ALU -> RANGE -> PHI
-assert lin.uops[ranges[0]-2].uop is UOps.LOAD
+assert lin.uops[ranges[0]-2].op is UOps.LOAD
assert ranges[1] == ranges[0]+5
-assert [x.uop for x in lin.uops[ranges[0]+3:ranges[0]+5]] == [UOps.LOAD, UOps.ALU]
+assert [x.op for x in lin.uops[ranges[0]+3:ranges[0]+5]] == [UOps.LOAD, UOps.ALU]
# LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI
else:
-assert lin.uops[ranges[0]-2].uop is UOps.LOAD
+assert lin.uops[ranges[0]-2].op is UOps.LOAD
assert ranges[1] == ranges[0]+3
-assert [x.uop for x in lin.uops[ranges[0]+1:ranges[0]+3]] == [UOps.LOAD, UOps.ALU]
+assert [x.op for x in lin.uops[ranges[0]+1:ranges[0]+3]] == [UOps.LOAD, UOps.ALU]

def test_range_outer_op_after_phi(self):
a = Tensor.randn(4, 1).realize()
out = a.sum() * a.sum()
lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0]
# RANGE -> LOAD -> PHI -> ALU
-end = max(i for i,u in enumerate(lin.uops) if u.uop is UOps.ENDRANGE)
-assert lin.uops[end+1].uop is UOps.ALU
+end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
+assert lin.uops[end+1].op is UOps.ALU

def test_range_outer_op_after_phi_nested_range(self):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum()
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0]
# RANGE -> LOAD -> PHI -> ALU
-end = max(i for i,u in enumerate(lin.uops) if u.uop is UOps.ENDRANGE)
-assert lin.uops[end+1].uop is UOps.ALU
+end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
+assert lin.uops[end+1].op is UOps.ALU

@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -202,11 +202,11 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*ast)
k.hand_coded_optimizations()
k.linearize()
-self.assertEqual(len(endifs:=[x for x in k.uops if x.uop is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.uop is UOps.IF]))
-self.assertEqual(len(barriers:=[x for x in k.uops if x.uop is UOps.BARRIER]), 3)
-self.assertEqual(k.uops[k.uops.uops.index(endifs[0])-1].uop, UOps.STORE)
+self.assertEqual(len(endifs:=[x for x in k.uops if x.op is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.op is UOps.IF]))
+self.assertEqual(len(barriers:=[x for x in k.uops if x.op is UOps.BARRIER]), 3)
+self.assertEqual(k.uops[k.uops.uops.index(endifs[0])-1].op, UOps.STORE)
self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+1], barriers[1])
-self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+2].uop, UOps.LOAD)
+self.assertEqual(k.uops[k.uops.uops.index(endifs[0])+2].op, UOps.LOAD)
self.assertLess(k.uops.uops.index(barriers[0]), k.uops.uops.index(ifs[0]))
self.assertLess(k.uops.uops.index(ifs[0]), k.uops.uops.index(endifs[0]))
self.assertLess(k.uops.uops.index(barriers[1]), k.uops.uops.index(ifs[1]))
@@ -245,13 +245,13 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*ast)
k.hand_coded_optimizations()
k.linearize()
-local_buf = [u for u in k.uops if u.uop is UOps.DEFINE_LOCAL]
-self.assertEqual(len(real_local_stores:=[u for u in k.uops if u.uop is UOps.STORE and any([lb in u.vin for lb in local_buf])]), 3, \
+local_buf = [u for u in k.uops if u.op is UOps.DEFINE_LOCAL]
+self.assertEqual(len(real_local_stores:=[u for u in k.uops if u.op is UOps.STORE and any([lb in u.src for lb in local_buf])]), 3, \
f"should have generated 3 BufferOps.STORE to the local buf but got {len(real_local_stores)}")
-self.assertEqual(len(real_local_loads:=[u for u in k.uops if u.uop is UOps.LOAD and any([lb in u.vin for lb in local_buf])]), 3, \
+self.assertEqual(len(real_local_loads:=[u for u in k.uops if u.op is UOps.LOAD and any([lb in u.src for lb in local_buf])]), 3, \
f"should have generated 3 BufferOps.LOAD to the local buf but got {len(real_local_loads)}")
-self.assertEqual((real_local_stores[1].vin[1].uop, real_local_stores[1].vin[1].arg), (UOps.CONST, 0))
-self.assertEqual((real_local_loads[1].vin[1].uop, real_local_loads[1].vin[1].arg), (UOps.CONST, 0))
+self.assertEqual((real_local_stores[1].src[1].op, real_local_stores[1].src[1].arg), (UOps.CONST, 0))
+self.assertEqual((real_local_loads[1].src[1].op, real_local_loads[1].src[1].arg), (UOps.CONST, 0))
x = Tensor.randn(3,27,32).realize()
helper_linearizer_ast(ast, [x], wanna_output=[x.numpy().std(axis=2, ddof=0).reshape(-1)])
@@ -261,9 +261,9 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*ast)
k.upcast()
k.linearize()
-define_globals = [u for u in k.uops if u.uop is UOps.DEFINE_GLOBAL]
-self.assertEqual(len([u for u in k.uops if u.uop is UOps.LOAD and define_globals[1] in u.vin]), 7)
-self.assertEqual(len([u for u in k.uops if u.uop is UOps.ALU and u.arg is BinaryOps.ADD]), 25)
+define_globals = [u for u in k.uops if u.op is UOps.DEFINE_GLOBAL]
+self.assertEqual(len([u for u in k.uops if u.op is UOps.LOAD and define_globals[1] in u.src]), 7)
+self.assertEqual(len([u for u in k.uops if u.op is UOps.ALU and u.arg is BinaryOps.ADD]), 25)
opts = [[Opt(op=OptOps.UPCAST, axis=0, amt=2)], [Opt(op=OptOps.UPCAST, axis=0, amt=4)]]
x = Tensor.randn(8,7).softmax().realize()
helper_linearizer_ast(ast, [x], opts=opts, wanna_output=[(x.numpy() - x.numpy().sum(axis=1, keepdims=True)).sum(axis=1)])
@@ -288,12 +288,12 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*ast)
k.hand_coded_optimizations()
k.linearize()
-def get_recursive_children(x:UOp): return set.union(set(x.vin), *[get_recursive_children(v) for v in x.vin])
+def get_recursive_children(x:UOp): return set.union(set(x.src), *[get_recursive_children(v) for v in x.src])
loop = None
for u in k.uops:
-if u.uop is UOps.RANGE: loop = u
+if u.op is UOps.RANGE: loop = u
elif loop is None: continue
-elif u.uop is UOps.ENDRANGE and loop in u.vin: loop = None
+elif u.op is UOps.ENDRANGE and loop in u.src: loop = None
else: self.assertIn(loop, get_recursive_children(u), f"Any uop within a loop should depend on the loop: {u}")
x = Tensor.randn(3, 27, 32).realize()
helper_linearizer_ast(ast, [x], wanna_output= \
@@ -358,7 +358,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
-num_loads = len([uop for uop in k.uops if uop.uop is UOps.LOAD])
+num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD])
assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"

@@ -377,7 +377,7 @@ class TestLinearizer(unittest.TestCase):
lin.linearize()

assert len(lin.uops.uops) <= 7, "too many uops"
-a_bufs = [u.uop for u in lin.uops.uops[-1].vin[2].vin]
+a_bufs = [u.op for u in lin.uops.uops[-1].src[2].src]
assert a_bufs == [UOps.LOAD, UOps.CONST]

def test_upcast_cse(self):
@@ -389,7 +389,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
-num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU])
+num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
assert num_ops <= 1, "more alu uops than needed"

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -401,11 +401,11 @@ class TestLinearizer(unittest.TestCase):
k.upcast()
k.upcast()
k.linearize()
-accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
-stores = [u for u in k.uops if u.uop is UOps.STORE]
+accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
+stores = [u for u in k.uops if u.op is UOps.STORE]
assert len(accs) == 0 # it's removed now
assert len(stores) == 1
-assert stores[0].vin[-1].dtype == dtypes.float.vec(4)
+assert stores[0].src[-1].dtype == dtypes.float.vec(4)

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -417,15 +417,15 @@ class TestLinearizer(unittest.TestCase):
k.hand_coded_optimizations()
k.linearize()

-accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
-stores = [u for u in k.uops if u.uop is UOps.STORE]
+accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
+stores = [u for u in k.uops if u.op is UOps.STORE]

# the first store is to lds and can be upcasted
-assert accs[0].dtype == stores[0].vin[-1].dtype == dtypes.float.vec(4)
-assert stores[0].vin[0].uop is UOps.DEFINE_LOCAL
+assert accs[0].dtype == stores[0].src[-1].dtype == dtypes.float.vec(4)
+assert stores[0].src[0].op is UOps.DEFINE_LOCAL
# the second store is to gds with no upcasts
-assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float
-assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL
+assert accs[1].dtype == stores[1].src[-1].dtype == dtypes.float
+assert stores[1].src[0].op is UOps.DEFINE_GLOBAL

@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
def test_upcast_multireduce_nested_local_upcast(self):
@@ -449,7 +449,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
-num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU])
+num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
assert num_ops == 0, "more alu uops than needed"

def test_sum_acc_dtype(self):
@@ -458,14 +458,14 @@ class TestLinearizer(unittest.TestCase):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Linearizer(*create_schedule([a.lazydata])[-1].ast)
k.linearize()
-local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC]
+local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype

def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Linearizer(*create_schedule([c.lazydata])[-1].ast)
k.linearize()
-local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC]
+local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype

tests = (
@@ -530,7 +530,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=2)
k.linearize()
-assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) > 0, "tensor core not triggered"
+assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"

prg = CompiledRunner(k.to_program())
@@ -571,8 +571,8 @@ class TestLinearizer(unittest.TestCase):
r = x.matmul(y, acc_dtype=tc.dtype_out)
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
for u in k.uops:
-if u.uop is UOps.WMMA:
-assert u.vin[-1].vin[0].uop != UOps.PHI
+if u.op is UOps.WMMA:
+assert u.src[-1].src[0].op != UOps.PHI

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_casted_phi(self):
@@ -581,9 +581,9 @@ class TestLinearizer(unittest.TestCase):
r = x.matmul(y, acc_dtype=tc.dtype_out)
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
for u in k.uops:
-if u.uop is UOps.WMMA:
-assert u.vin[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
-assert u.vin[-1].vin[0].uop != UOps.PHI
+if u.op is UOps.WMMA:
+assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
+assert u.src[-1].src[0].op != UOps.PHI

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_casted_phi_with_children(self):
@@ -593,9 +593,9 @@ class TestLinearizer(unittest.TestCase):
r = x.matmul(y, acc_dtype=tc.dtype_out).relu()
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
for u in k.uops:
-if u.uop is UOps.WMMA:
-assert u.vin[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
-assert u.vin[-1].vin[0].uop != UOps.PHI
+if u.op is UOps.WMMA:
+assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
+assert u.src[-1].src[0].op != UOps.PHI

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_simple_unroll_no_between_phi_dependencies(self):
@@ -604,11 +604,11 @@ class TestLinearizer(unittest.TestCase):
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1]
# the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x PHI -> ENDRANGE
for u in k.uops:
-if u.uop is UOps.PHI:
-assert u.vin[1].uop is UOps.ALU
+if u.op is UOps.PHI:
+assert u.src[1].op is UOps.ALU
# children of PHI are placed after ENDRANGE
-if any(x.uop is UOps.PHI for x in u.vin):
-end_range = [i for i, x in enumerate(k.uops) if x.uop is UOps.ENDRANGE][0]
+if any(x.op is UOps.PHI for x in u.src):
+end_range = [i for i, x in enumerate(k.uops) if x.op is UOps.ENDRANGE][0]
assert end_range < k.uops.uops.index(u)

def test_grouped_dims(self):
@@ -649,7 +649,7 @@ class TestLinearizer(unittest.TestCase):
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(*sched[0].ast)
-assert not any(u.uop is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
+assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"

def test_assign_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
@@ -680,10 +680,10 @@ class TestLinearizer(unittest.TestCase):
k.hand_coded_optimizations()
uops = list(k.linearize().uops)
# ignore kernel optimized IF statements for now
-if if_op:=next((u for u in uops if u.uop is UOps.IF), None):
+if if_op:=next((u for u in uops if u.op is UOps.IF), None):
uops = uops[:uops.index(if_op)]
-assert len(set([u.uop for u in uops if u.uop in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both"
-assert len([u for u in uops if u.uop is UOps.PHI]) == 0, "PHI should have been simplified"
+assert len(set([u.op for u in uops if u.op in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both"
+assert len([u for u in uops if u.op is UOps.PHI]) == 0, "PHI should have been simplified"
# TODO: once uops track min/max this will be fixed
#assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
@@ -708,17 +708,17 @@ class TestLinearizer(unittest.TestCase):
out = x.matmul(y)
k = helper_linearizer_opt(out)[-1]
# check that the float4 cast collapses
-store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE]
+store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE]
for val in store_vals:
-assert val.dtype == dtypes.float.vec(4) and val.uop is not UOps.CAST
+assert val.dtype == dtypes.float.vec(4) and val.op is not UOps.CAST

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_grouped_store_values(self):
x = Tensor.randn((4,3,6,6)).realize()
out = x.flip((0,1)).contiguous()
k = helper_linearizer_opt(out)[-1]
-store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
-assert store_val.dtype == dtypes.float.vec(4) and store_val.uop is not UOps.CAST
+store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0]
+assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.CAST

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -729,16 +729,16 @@ class TestLinearizer(unittest.TestCase):
opt = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
k = helper_linearizer_opt(out, opts=[opt])[-1]
-def get_recursive(uop): return set.union(set(uop.vin), [uop], *[get_recursive(v) for v in uop.vin])
-local_stores = [u for u in k.uops if u.uop is UOps.STORE and any(x.uop is UOps.DEFINE_LOCAL for x in get_recursive(u.vin[0]))]
-global_stores = [u for u in k.uops if u.uop is UOps.STORE and any(x.uop is UOps.DEFINE_GLOBAL for x in get_recursive(u.vin[0]))]
-barrier = [u for u in k.uops if u.uop is UOps.BARRIER][0]
+def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
+local_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
+global_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
+barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
-assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop is not UOps.CAST
+assert store.src[-1].dtype == dtypes.float.vec(2) and store.src[-1].op is not UOps.CAST
# check the children's vins
-assert barrier.vin == tuple(local_stores)
-assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1
+assert barrier.src == tuple(local_stores)
+assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -747,14 +747,14 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = helper_linearizer_opt(r)[-1]
-stores = [u for u in k.uops if u.uop is UOps.STORE]
+stores = [u for u in k.uops if u.op is UOps.STORE]

# the float4 value stores directly in lds and we skip upcast
-assert stores[0].vin[-1].dtype == dtypes.float.vec(4)
-assert stores[0].vin[-1].uop is not UOps.CAST
+assert stores[0].src[-1].dtype == dtypes.float.vec(4)
+assert stores[0].src[-1].op is not UOps.CAST

# the global store doesn't change
-assert stores[1].vin[-1].dtype == dtypes.float
+assert stores[1].src[-1].dtype == dtypes.float

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts(self):
@@ -764,8 +764,8 @@ class TestLinearizer(unittest.TestCase):
Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
]
k = helper_linearizer_ast(ast, [Tensor.empty(240*40).realize()], opts=[opt])[-1]
-out = [u for u in k.uops if u.uop is UOps.STORE][0]
-assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4)
+out = [u for u in k.uops if u.op is UOps.STORE][0]
+assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(4)

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -775,15 +775,15 @@ class TestLinearizer(unittest.TestCase):
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
k = helper_linearizer_ast(ast, [Tensor.empty(8*32).realize()], opts=[opt])[-1]
-out = [u for u in k.uops if u.uop is UOps.STORE][0]
-assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(2)
+out = [u for u in k.uops if u.op is UOps.STORE][0]
+assert out.src[-1].op is UOps.CAST and out.src[-1].dtype == dtypes.float.vec(2)

@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
@staticmethod
def count_float4(k):
-return (len([uop for uop in k.uops if uop.uop is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
-len([uop for uop in k.uops if uop.uop is UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)]))
+return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
+len([uop for uop in k.uops if uop.op is UOps.STORE and len(uop.src) == 3 and uop.src[2].dtype == dtypes.float.vec(4)]))

# TODO: express opts below as auto opts
@@ -1178,10 +1178,10 @@ class TestKernelOpts(unittest.TestCase):
for k in lins:
seen_bar = False
for u in k.uops:
-if u.uop is UOps.BARRIER:
+if u.op is UOps.BARRIER:
assert not seen_bar, "redudant barrier"
seen_bar = True
-elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False
+elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False

@unittest.skip("TODO: broken")
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
@@ -1200,10 +1200,10 @@ class TestKernelOpts(unittest.TestCase):
for k in lins:
seen_bar = False
for u in k.uops:
-if u.uop is UOps.BARRIER:
+if u.op is UOps.BARRIER:
assert not seen_bar, "redudant barrier"
seen_bar = True
-elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False
+elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False

@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -1225,10 +1225,10 @@ class TestKernelOpts(unittest.TestCase):
for k in lins:
seen_bar = False
for u in k.uops:
-if u.uop is UOps.BARRIER:
+if u.op is UOps.BARRIER:
assert not seen_bar, "redudant barrier"
seen_bar = True
-elif (u.uop is UOps.LOAD or u.uop is UOps.STORE): seen_bar = False
+elif (u.op is UOps.LOAD or u.op is UOps.STORE): seen_bar = False

def test_upcasts(self):
N = 16
@@ -6,7 +6,7 @@ from tinygrad.codegen.uops import UOpGraph, UOps, PatternMatcher, UOp, UPat
class TestPatternMatcher(unittest.TestCase):
def assert_equiv_uops(self, uop1:UOp, uop2:UOp):
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
-self.assertEqual(uop1.uop, uop2.uop)
+self.assertEqual(uop1.op, uop2.op)
self.assertEqual(uop1.dtype, uop2.dtype)
self.assertEqual(uop1.arg, uop2.arg)

@@ -64,7 +64,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c3), c3)

def test_dup_name(self):
-matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
+matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
@@ -91,13 +91,13 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)

def test_vin_one(self):
-matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
+matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c2), None)
-matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
+matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)])
c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), None)
@@ -105,7 +105,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c5), None)

def test_vin_permutations(self):
-matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
+matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -118,7 +118,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c6), None)

def test_vin_repeat(self):
-matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=UPat(UOps.CONST)), lambda x: x)])
+matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
@@ -127,7 +127,7 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c4), None)

def test_allow_len(self):
-matcher = PatternMatcher([(UPat(UOps.ALU, name="x", vin=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)])
+matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_len={3}), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
@@ -157,7 +157,7 @@ class TestPatternMatcher(unittest.TestCase):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int),
lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
matcher.rewrite_graph(uops)
-uops.remove_childless(set(x for x in uops if x.uop in {UOps.STORE}))
+uops.remove_childless(set(x for x in uops if x.op in {UOps.STORE}))

self.assertEqual(len(uops.uops), 3)
@@ -12,7 +12,7 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
-self.assertEqual(out.uop, UOps.CONST)
+self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 3.0)

def test_where_same_fold(self):
@@ -24,7 +24,7 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
-self.assertEqual(out.uop, UOps.CONST)
+self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 1.0)

def test_where_const_fold(self):
@@ -35,7 +35,7 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
-self.assertEqual(out.uop, UOps.CONST)
+self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 2.0)

def test_const_cast(self):
@@ -44,7 +44,7 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 1)
out = g.uops[-1]
-self.assertEqual(out.uop, UOps.CONST)
+self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.arg, 0)

def test_cast_vectorized_fold(self):
@@ -56,7 +56,7 @@ class TestUOpGraph(unittest.TestCase):
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu))
g = UOpGraph([out])
-self.assertEqual(len([x for x in g.uops if x.uop is UOps.CAST]), 0)
+self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 0)

def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
@@ -67,10 +67,10 @@ class TestUOpGraph(unittest.TestCase):
g = UOpGraph([out])
self.assertEqual(len(g.uops), 3)
out = g.uops[-1]
-self.assertEqual(out.uop, UOps.ALU)
+self.assertEqual(out.op, UOps.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)
-self.assertEqual(out.vin[1].uop, UOps.CONST)
-self.assertEqual(out.vin[1].arg, 6)
+self.assertEqual(out.src[1].op, UOps.CONST)
+self.assertEqual(out.src[1].arg, 6)

if __name__ == '__main__':
unittest.main(verbosity=2)
@@ -226,7 +226,7 @@ class TestConstantFolding(unittest.TestCase):
si = create_schedule([t.lazydata])
assert len(si) == 1
ji = lower_schedule_item(si[-1])
-assert any(uop.uop is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.uop for uop in ji.prg.p.uops]} does not contain bitcast"
+assert any(uop.op is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"

class TestGatedStoreRewrite(unittest.TestCase):
@unittest.skip("not yet implemented")
@@ -240,9 +240,9 @@ class TestGatedStoreRewrite(unittest.TestCase):
gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
uops = UOpGraph([UOp(UOps.STORE, None, (gmem, idx, value, gate))])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
-if_uop = next(u for u in uops if u.uop is UOps.IF)
-endif = next(u for u in uops if u.uop is UOps.ENDIF)
-assert endif.vin[0] is if_uop
+if_uop = next(u for u in uops if u.op is UOps.IF)
+endif = next(u for u in uops if u.op is UOps.ENDIF)
+assert endif.src[0] is if_uop
nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
assert nested_uops == (gmem, gidx0, idx, value)

@@ -261,9 +261,9 @@ class TestGatedStoreRewrite(unittest.TestCase):
outs.append(UOp(UOps.STORE, None, (gmem1, idx, value1)))
uops = UOpGraph(outs)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
-if_uop = next(u for u in uops if u.uop is UOps.IF)
-endif = next(u for u in uops if u.uop is UOps.ENDIF)
-assert endif.vin[0] is if_uop
+if_uop = next(u for u in uops if u.op is UOps.IF)
+endif = next(u for u in uops if u.op is UOps.ENDIF)
+assert endif.src[0] is if_uop
nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
assert nested_uops == (gmem0, value0)
@@ -29,23 +29,23 @@ class UOps(Enum):
def ufix(dtype: Optional[DType], x): return UOp.const(dtype, x) if not isinstance(x, UOp) else x
@dataclass(eq=False)
class UOp:
-uop: UOps
+op: UOps
dtype: Optional[DType] = None
-vin: Tuple[UOp, ...] = tuple()
+src: Tuple[UOp, ...] = tuple()
arg: Any = None
-def tuple(self): return (self.uop, self.dtype, self.vin, self.arg)
+def tuple(self): return (self.op, self.dtype, self.src, self.arg)
def commutative(self) -> bool:
-return self.uop is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
+return self.op is UOps.ALU and self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR}
@functools.cached_property
def cmp_tuple(self):
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
-return (self.uop.value, (self.arg if self.uop is not UOps.DEFINE_VAR else self.arg.expr) if self.uop is not UOps.ALU else \
-(type(self.uop), self.uop.value), self.dtype, self.vin)
+return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
+(type(self.op), self.op.value), self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
def __repr__(self):
-return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
+return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
-def name(self, name:Optional[str]): return UOp(UOps.VAR, vin=(self,), arg=name)
+def name(self, name:Optional[str]): return UOp(UOps.VAR, src=(self,), arg=name)
def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
def __radd__(self, x): return UOp.alu(BinaryOps.ADD, ufix(self.dtype, x), self)
@@ -75,35 +75,35 @@ class UOp:
@staticmethod
def cvar(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
@functools.cached_property
-def parents(self) -> Set[UOp]: return set.union(set(self.vin), *[x.parents for x in self.vin])
+def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src])
@property # parents with self
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
-def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.uop is UOps.DEFINE_VAR])
+def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])

def uop_alu_resolve(u:UOp) -> sint:
-if u.uop is UOps.CONST: return u.arg
-if u.uop is UOps.DEFINE_VAR: return u.arg
-if u.uop is UOps.SPECIAL: return u.arg[2]-1
-if u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
-if u.uop is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.vin[0]) * (2**cast(int, uop_alu_resolve(u.vin[1])))
-if u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
-raise RuntimeError(f"ALU resolve fail @ {u.uop}")
+if u.op is UOps.CONST: return u.arg
+if u.op is UOps.DEFINE_VAR: return u.arg
+if u.op is UOps.SPECIAL: return u.arg[2]-1
+if u.op is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.src[0]) * uop_alu_resolve(u.src[1])
+if u.op is UOps.ALU and u.arg is BinaryOps.SHL: return uop_alu_resolve(u.src[0]) * (2**cast(int, uop_alu_resolve(u.src[1])))
+if u.op is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.src[0]) + uop_alu_resolve(u.src[1])
+raise RuntimeError(f"ALU resolve fail @ {u.op}")

# *** simplification logic ***

@dataclass(frozen=True)
class UPat:
-uop: Optional[Union[UOps, Set[UOps]]] = None
+op: Optional[Union[UOps, Set[UOps]]] = None
arg: Any = None
-vin: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
+src: Optional[Union[Tuple[UPat, ...], List[UPat], UPat]] = None
name: Optional[str] = None
dtype: Optional[Union[DType, Set[DType]]] = None
allow_len: Set[int] = field(default_factory=set)

@staticmethod
def compile(u: UOp, name:Optional[str]=None) -> UPat:
-if u.uop is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.vin) == 0 else UPat.compile(u.vin[0], name or u.arg)
-return UPat(u.uop, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.vin]) if u.vin != () else None, name, u.dtype)
+if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
+return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.src]) if u.src != () else None, name, u.dtype)

T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
@@ -117,15 +117,15 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
if pat.name is not None: store[pat.name] = uop
if pat.arg is not None and __unmatch(pat.arg, uop.arg): return False
if pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype): return False
-if pat.uop is not None and __unmatch(pat.uop, uop.uop): return False
-if pat.vin is None: return True
+if pat.op is not None and __unmatch(pat.op, uop.op): return False
+if pat.src is None: return True
# only one if it's a tuple
# try all permutations if it's a list
# repeat if it's a UPat
-for vp in itertools.permutations(pat.vin) if isinstance(pat.vin,list) else ([pat.vin] if isinstance(pat.vin,tuple) else [(pat.vin,)*len(uop.vin)]):
-if len(uop.vin) != len(vp) and (len(uop.vin) not in pat.allow_len): return False
+for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
+if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
new_store = store.copy()
-if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
+if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
store.update(new_store)
return True
return False
@@ -137,14 +137,14 @@ class PatternMatcher:
# uop is required, arg is optional
for p,fxn in self.patterns:
if isinstance(p, UOp): p = UPat.compile(p)
-assert p.uop is not None
-if isinstance(p.uop, set):
-for uop in p.uop: self.pdict[(uop, p.arg)].append((p, fxn))
+assert p.op is not None
+if isinstance(p.op, set):
+for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn))
else:
-self.pdict[(p.uop, p.arg)].append((p, fxn))
+self.pdict[(p.op, p.arg)].append((p, fxn))

def rewrite(self, uop:UOp) -> Optional[UOp]:
-for p,fxn in itertools.chain(self.pdict[(uop.uop, uop.arg)], self.pdict[(uop.uop, None)]):
+for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
store: Dict[str, UOp] = {}
if _match(uop, p, store): return fxn(**store)
return None
@@ -152,7 +152,7 @@ class PatternMatcher:
def sum_collapse(phi_input, loop, val1, val2):
for v1,v2 in [(val1, val2), (val2, val1)]:
if loop not in v1.parents:
-loop_range = loop.vin[1]-loop.vin[0]
+loop_range = loop.src[1]-loop.src[0]
ret = v1*loop_range.cast(v1.dtype)
return UOp(UOps.PHI, phi_input.dtype, (phi_input, v2))+ret
return None
@@ -168,34 +168,34 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
# this is symbolic 2.0
constant_folder = PatternMatcher([
# arange loop folding (early)
-(UPat(UOps.ALU, TernaryOps.WHERE, vin=(UPat(UOps.ALU, BinaryOps.CMPLT, vin=(
-UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL,
-vin=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, vin=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
+(UPat(UOps.ALU, TernaryOps.WHERE, src=(UPat(UOps.ALU, BinaryOps.CMPLT, src=(
+UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="idx"), UPat(UOps.ALU, BinaryOps.MUL,
+src=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
# sum collapse to mul (with possible GEP)
-(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="phi_input", vin=(UPat(UOps.RANGE, name="loop"),)),
-UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
-(UPat(UOps.PHI, vin=(UPat(UOps.GEP, name="phi_input",
-vin=(UPat(UOps.DEFINE_ACC, vin=(UPat(UOps.RANGE, name="loop"),)),)),
-UPat(UOps.ALU, BinaryOps.ADD, vin=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
+(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=(UPat(UOps.RANGE, name="loop"),)),
+UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
+(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input",
+src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.RANGE, name="loop"),)),)),
+UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
# deal with UNMUL
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"),
-UPat(UOps.UNMUL, vin=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
+UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
lambda c1,c2,v: v if c1.arg == c2.arg else None),
-(UOp(UOps.UNMUL, vin=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
-(UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
+(UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
+(UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))),
# max on special can go away (TODO: special should be variable, same thing applies)
(UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
# const rules
-(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
-(UPat(UOps.CAST, name="root", vin=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
+(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
+(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
-(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])),
-(UPat(UOps.PHI, vin=(UPat(UOps.DEFINE_ACC, vin=tuple()), UPat(name="x"))), lambda x: x),
-(UPat(UOps.PHI, vin=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
+(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])),
+(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=tuple()), UPat(name="x"))), lambda x: x),
+(UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
-(UPat(UOps.DEFINE_ACC, name="root", vin=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])),
-(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
+(UPat(UOps.DEFINE_ACC, name="root", src=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])),
+(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
# max -2147483648
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
# -(-x) -> x
@@ -209,7 +209,7 @@ constant_folder = PatternMatcher([
(UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
(UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
-(UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
+(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
# ** self folding **
(UOp.var('x') + 0, lambda x: x), # x+0 -> x
(UOp.var('x') - 0, lambda x: x), # x-0 -> x
@@ -222,8 +222,8 @@ constant_folder = PatternMatcher([
(UOp.var('x') * 0, lambda x: x if isinstance(x.arg, float) and math.isnan(x.arg) else UOp.const(x.dtype, 0)),
(UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
# ** load/store folding **
-(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"),
-UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)),
+(UPat(UOps.STORE, src=(UPat(name="buf"), UPat(name="idx"),
+UPat(UOps.LOAD, src=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)),
# ** two stage add/sub folding **
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
@@ -247,17 +247,17 @@ constant_folder = PatternMatcher([
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# store float4/float2 directly (remove CAST/GEP)
-(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(4)))), UOp.store),
-(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(2)))), UOp.store),
+(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(4)))), UOp.store),
+(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, src=tuple(UOp(UOps.GEP, arg=i, src=(UOp.var("val"),)) for i in range(2)))), UOp.store),
# CAST-PHI-GEP -> PHI-CAST
-(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
+(UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
-(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
+(UPat(UOps.CAST, name="root", src=tuple(UPat(UOps.PHI, src=(UPat(UOps.GEP, i, src=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
# NEG/CMPLT -> CMPLT
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
-(UPat(UOps.CAST, name="root"), lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
+(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
])

# *** uop graph ***
@@ -271,8 +271,8 @@ class UOpGraph:
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
def __getitem__(self, index) -> UOp: return self.uops[index]

def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR], key=lambda v: v.expr)
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]
def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.op is UOps.DEFINE_VAR], key=lambda v: v.expr)
def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.op is UOps.DEFINE_GLOBAL]

@property
def uops(self):
@@ -285,7 +285,7 @@ class UOpGraph:

def print(self):
for i,u in enumerate(self):
print(f"{i:4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}")
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")

def graph_rewrite(self, sink, pm):
# recursive rewrite
@@ -305,7 +305,7 @@ class UOpGraph:
recurse_cnt += 1
changed += recurse_cnt
# NOTE: this changes UOp, so we have to delete caches
up.vin = tuple(rewrite(x) for x in up.vin)
up.src = tuple(rewrite(x) for x in up.src)
if 'parents' in up.__dict__: delattr(up, 'parents')
if 'cmp_tuple' in up.__dict__: delattr(up, 'cmp_tuple')
# replace with cached nodes
@@ -327,16 +327,16 @@ class UOpGraph:
n = unprocessed_nodes.pop(0)
if n in all_nodes: continue
all_nodes[n] = None
for x in n.vin:
for x in n.src:
early_in_degree[n] += 1
children[x].append(n)
unprocessed_nodes += list(n.vin)
unprocessed_nodes += list(n.src)
early_queue = [x for x in all_nodes if early_in_degree[x] == 0]
replace_nodes: Dict[UOp, UOp] = {}
while len(early_queue):
n = early_queue.pop(0)
if n in replace_nodes: continue
key = (n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
key = (n.op, n.dtype, tuple(replace_nodes.get(x, x) for x in n.src), n.arg)
if found:=self.nodes.get(key): replace_nodes[n] = found
else: replace_nodes[n] = self.nodes[key] = UOp(*key)
for x in children[n]:
@@ -367,29 +367,29 @@ class UOpGraph:
def add_parents(u:UOp):
if u in nodes: return
nodes[u] = None
for x in u.vin:
for x in u.src:
add_parents(x)
in_degree[u] += 1
graph[x].append(u)
if u.uop is UOps.RANGE: loops.append(u)
if u.uop is UOps.IF: ifs.append(u)
sink = UOp(UOps.SINK, None, tuple(x for x in sink.vin if x.uop is not UOps.NOOP))
if u.op is UOps.RANGE: loops.append(u)
if u.op is UOps.IF: ifs.append(u)
sink = UOp(UOps.SINK, None, tuple(x for x in sink.src if x.op is not UOps.NOOP))
add_parents(sink)

@functools.lru_cache(None)
def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
if x.uop is UOps.SINK: return set()
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in graph[x] if x.uop is not end]))
if x.op is UOps.SINK: return set()
return set.union(set((x,)) if include_self else set(), *([get_recursive_children(u, end, True) for u in graph[x] if x.op is not end]))
# scope children impact the toposort and END* insertion
end_for_uop = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.PHI, UOps.ENDRANGE)}
scope_children = {p:get_recursive_children(p, end_for_uop[p.uop][0]) for p in (loops+ifs)[::-1]}
scope_children = {p:get_recursive_children(p, end_for_uop[p.op][0]) for p in (loops+ifs)[::-1]}

queue: List = []
def push(u):
priority = 0
# prefer uops that are loop children
for l, ss in scope_children.items():
if l.uop is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
if l.op is UOps.RANGE and u in ss: priority -= l.arg[0]*1000 + l.arg[1]
heapq.heappush(queue, (priority, u))

for u in nodes:
@@ -403,20 +403,20 @@ class UOpGraph:
while queue:
p,x = heapq.heappop(queue)
if DEBUG >= 7: print(p,x)
if x.uop is UOps.DEFINE_ACC and len(x.vin):
idx = min([self._uops.index(l) for l in x.vin])
if x.op is UOps.DEFINE_ACC and len(x.src):
idx = min([self._uops.index(l) for l in x.src])
self._uops.insert(idx, x)
else:
self._uops.append(x)
for u, ss in scope_children.items():
if x in ss:
ss.remove(x)
if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.uop][1], None, (u,)))
if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
for u in graph[x]:
in_degree[u] -= 1
if in_degree[u] == 0: push(u)

assert self._uops[-1].uop is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
self._uops = self._uops[:-1]

if type_verify: self.type_verify()
@@ -431,34 +431,34 @@ class UOpGraph:
dont_count: Set[UOp] = set()
if ignore_indexing:
for u in self.uops:
if u.uop is UOps.LOAD:
dont_count = dont_count.union(u.vin[1].sparents)
if len(u.vin) > 3: dont_count = dont_count.union(u.vin[2].sparents)
elif u.uop is UOps.STORE:
dont_count = dont_count.union(u.vin[1].sparents)
if len(u.vin) > 3: dont_count = dont_count.union(u.vin[3].sparents)
if u.op is UOps.LOAD:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
elif u.op is UOps.STORE:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
for u in self.uops:
if u.uop is UOps.RANGE:
if u.op is UOps.RANGE:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.vin[1])
elif u.uop is UOps.ENDRANGE:
mults *= uop_alu_resolve(u.src[1])
elif u.op is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.uop is UOps.LOAD:
elif u.op is UOps.LOAD:
assert u.dtype is not None
mem += u.dtype.itemsize * mults
elif u.uop is UOps.STORE:
assert u.vin[2].dtype is not None
mem += u.vin[2].dtype.itemsize * mults
elif u.uop is UOps.ALU and u not in dont_count:
elif u.op is UOps.STORE:
assert u.src[2].dtype is not None
mem += u.src[2].dtype.itemsize * mults
elif u.op is UOps.ALU and u not in dont_count:
flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
elif u.uop is UOps.WMMA and u not in dont_count:
elif u.op is UOps.WMMA and u not in dont_count:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults
return flops, mem

def type_verify(self):
for u in self.uops:
uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype
uop, arg, vin, dtype = u.op, u.arg, u.src, u.dtype
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
if uop is UOps.DEFINE_ACC: arg = arg[0]
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
@@ -77,13 +77,12 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):

def _tree(luop:Union[LazyOp,UOp], cycles, cnt, prefix=""):
cnt[0] += 1
if len(src:=luop.vin if hasattr(luop,'vin')else luop.src) == 0:
return [f"━━ {prefix}{(luop.op if hasattr(luop, 'op') else luop.uop).name} {luop.arg if luop.arg else ''}"]
if len(luop.src) == 0: return [f"━━ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
if (lid := id(luop)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
return [f"━⬆︎ goto {cycles[id(luop)][0]}: {(luop.op if hasattr(luop,'op')else luop.uop).name}"]
return [f"━⬆︎ goto {cycles[id(luop)][0]}: {luop.op.name}"]
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
lines = [f"━┳ {prefix}{(luop.op if hasattr(luop,'op')else luop.uop).name} {luop.arg if luop.arg else ''}"]
childs = [_tree(c, cycles, cnt) for c in src[:]]
lines = [f"━┳ {prefix}{luop.op.name} {luop.arg if luop.arg else ''}"]
childs = [_tree(c, cycles, cnt) for c in luop.src[:]]
for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]

@@ -95,7 +94,7 @@ def graph_uops(uops:List[UOp]):
UOps.RANGE: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"}
G = nx.DiGraph()
for u in uops:
if u.uop in {UOps.ENDRANGE, UOps.ENDIF}: continue
G.add_node(uops.index(u), label=f"{str(u.uop)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.uop, "#ffffff")) # noqa: E501
for v in u.vin: G.add_edge(uops.index(v), uops.index(u))
if u.op in {UOps.ENDRANGE, UOps.ENDIF}: continue
G.add_node(uops.index(u), label=f"{str(u.op)[5:]}{(' '+str(u.arg)) if u.arg is not None else ''}\n{str(u.dtype)}", style="filled", fillcolor=colors.get(u.op, "#ffffff")) # noqa: E501
for v in u.src: G.add_edge(uops.index(v), uops.index(u))
save_graph(G, f'{GRAPHPATH}.uops', '-Grankdir=LR')
@@ -131,22 +131,22 @@ class PTXRenderer(Renderer):
return ret

for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.IF:
assert vin[0].dtype is not None
kk(*self.render_bra(f"IF_{r[vin[0]][1:]}_{cast(List, uops._uops).index(u)}", _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True)))
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDRANGE:
kk(self.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, self.types[dtypes.int]),
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, self.types[dtypes.int]))
self.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(f"LOOP_{r[vin[0]][1:]}", pred))
elif uop is UOps.ENDIF:
kk(f"IF_{r[vin[0].vin[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
kk(f"IF_{r[vin[0].src[0]][1:]}_{cast(List, uops._uops).index(vin[0])}:")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
assert vin[0].dtype == dtypes.int64, "store isn't int64"
assert vin[1].uop is UOps.CONST, f"store isn't const {u}"
mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
assert vin[1].op is UOps.CONST, f"store isn't const {u}"
mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
if vin[2].dtype.count > 1:
kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \
f"st{mem_type}.v{vin[2].dtype.count}.{self.mem_types[vin[2].dtype.scalar()]} [{r[vin[0]]}+{vin[1].arg}], {{{', '.join(r[vin[2]])}}};")
@@ -178,8 +178,8 @@ class PTXRenderer(Renderer):
elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
elif uop is UOps.LOAD:
assert vin[0].dtype == dtypes.int64, "load isn't int64"
assert vin[1].uop is UOps.CONST, f"load isn't const {u}"
mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
assert vin[1].op is UOps.CONST, f"load isn't const {u}"
mem_type = '.shared' if vin[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(vin)>3):
@@ -227,44 +227,44 @@ class PTXRenderer(Renderer):

ptx_matcher = PatternMatcher([
(UPat(UOps.ALU, BinaryOps.MUL, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="mul")]),
lambda root, mul, const: UOp(UOps.ALU, root.dtype, (mul, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHL)),
(UPat(UOps.ALU, BinaryOps.IDIV, name="root", dtype=set([dt for dt in dtypes.fields().values() if dtypes.is_int(dt)]),
vin=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
src=[UPat(UOps.CONST, set([2**i for i in range(64)]), name="const"), UPat(name="div")]),
lambda root, div, const: UOp(UOps.ALU, root.dtype, (div, UOp.const(root.dtype, int(math.log2(const.arg)))), BinaryOps.SHR)),
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR)),
(UPat(UOps.ALU, BinaryOps.CMPNE, (UPat(dtype=dtypes.bool),UPat()), "root"), lambda root: UOp(root.op, root.dtype, root.src, BinaryOps.XOR)),
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(name="x", dtype=dtypes.bool),UPat(name="y")), "root"),
lambda root,x,y: UOp(root.uop, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
lambda root,x,y: UOp(root.op, root.dtype, (UOp(UOps.ALU, dtypes.bool, (x,), UnaryOps.NEG), y), BinaryOps.MUL)),
(UPat(UOps.ALU, BinaryOps.ADD,
[UPat(name="non_muls"), UPat(UOps.ALU, BinaryOps.MUL, name="muls")], "root"),
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.vin + (non_muls,), TernaryOps.MULACC)),
lambda root, muls, non_muls: UOp(UOps.ALU, root.dtype, muls.src + (non_muls,), TernaryOps.MULACC)),
*[(UPat(UOps.ALU, op, dtype=dtypes.half, name="x"),
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.uop, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.vin]), x.arg),)))
lambda x: UOp(UOps.CAST, dtypes.half, (UOp(x.op, dtypes.float32, tuple([UOp(UOps.CAST, dtypes.float32, (vv,)) for vv in x.src]), x.arg),)))
for op in PTXRenderer.asm_for_op.keys() if op not in PTXRenderer.supports_half],
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, vin=(UPat(),UPat())),
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.uop, dtypes.uint8, root.vin, root.arg),))),
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
lambda root,z: UOp(root.uop, root.dtype, root.vin[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
(UPat(UOps.STORE, name="root", vin=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
lambda root,g: UOp(root.uop, root.dtype, root.vin[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(name="x"),UPat(name="y"),UPat(name="z"),UPat(name="k"))),
lambda root,x,y,z,k: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.int8, (x,y,z,UOp(UOps.CAST, dtypes.uint8, (k,)))),), root.arg)),
(UPat(UOps.LOAD, name="root", dtype=dtypes.bool, src=(UPat(),UPat())),
lambda root: UOp(UOps.CAST, dtypes.bool, (UOp(root.op, dtypes.uint8, root.src, root.arg),))),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool), UPat())),
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(name="z",dtype=dtypes.bool))),
lambda root,z: UOp(root.op, root.dtype, root.src[:2] + (UOp(UOps.CAST, dtypes.uint8, (z,)),), root.arg)),
(UPat(UOps.STORE, name="root", src=(UPat(),UPat(),UPat(),UPat(name="g", dtype=dtypes.int))),
lambda root,g: UOp(root.op, root.dtype, root.src[:3] + (UOp(UOps.CAST, dtypes.bool, (g,)),), root.arg)),
# ptr_ar (load/store)
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.ALU, BinaryOps.ADD, vin=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
lambda root, alu, const: UOp(root.uop, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
UOp.const(const.dtype, root.vin[0].dtype.itemsize)*const)+root.vin[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.ALU, BinaryOps.ADD, src=[UPat(name="alu"), UPat(UOps.CONST, name="const")]))),
lambda root, alu, const: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
UOp.const(const.dtype, root.src[0].dtype.itemsize)*const)+root.src[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(UOps.CONST, name="const"))),
lambda root, const: UOp(root.uop, root.dtype, (root.vin[0].cast(dtypes.int64),
UOp.const(dtypes.int64, const.arg * root.vin[0].dtype.itemsize),
)+root.vin[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, vin=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
lambda root, const: UOp(root.op, root.dtype, (root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, const.arg * root.src[0].dtype.itemsize),
)+root.src[2:])),
(UPat({UOps.LOAD, UOps.STORE}, name="root", allow_len={2,3,4,5}, src=(UPat({UOps.DEFINE_LOCAL,UOps.DEFINE_GLOBAL}),
UPat(name="alu"))), # no const here
lambda root, alu: UOp(root.uop, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.vin[0].dtype.itemsize)+root.vin[0].cast(dtypes.int64),
UOp.const(dtypes.int64, 0))+root.vin[2:])),
lambda root, alu: UOp(root.op, root.dtype,
(alu.cast(dtypes.int64)*UOp.const(dtypes.int64, root.src[0].dtype.itemsize)+root.src[0].cast(dtypes.int64),
UOp.const(dtypes.int64, 0))+root.src[2:])),
])
@@ -104,10 +104,10 @@ class CStyleLanguage(Renderer):
c[prefix] += 1
return ret

child_count = Counter(v for ru in uops for v in ru.vin)
child_count = Counter(v for ru in uops for v in ru.src)

for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
# these four uops don't have output dtypes
if uop is UOps.IF:
kk(f"if ({r[vin[0]]}) {{")
@@ -118,7 +118,7 @@ class CStyleLanguage(Renderer):
kk("}")
elif uop is UOps.STORE:
assert vin[0].dtype is not None and vin[2].dtype is not None
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
rendered_store = self.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
kk(f"if ({r[vin[3]]}) {{ {rendered_store} }}" if len(vin) > 3 else rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
@@ -138,7 +138,7 @@ class CStyleLanguage(Renderer):
kk(f"int {args[1]} = {self.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
val = self.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].op is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(vin) > 3: val = self.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
@@ -169,7 +169,7 @@ class CStyleLanguage(Renderer):
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert vin[0].dtype is not None
from_ssa = vin[0].uop in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
from_ssa = vin[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[vin[0]] if from_ssa else f"{(r[vin[0]])}") + (f"[{args}]" if vin[0].dtype.count > 4 else f".{'xyzw'[args]}")
else: raise RuntimeError(f"failed to render {uop}")
@@ -235,7 +235,7 @@ class MetalRenderer(CStyleLanguage):
return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)

def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA])
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
@@ -285,7 +285,7 @@ class CUDARenderer(CStyleLanguage):
prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]

# TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
for arg in dedup([uop.arg for uop in uops if uop.uop is UOps.WMMA]):
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
@@ -370,7 +370,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {

prefix += [_make_hip_dtype(*x) for x in vec_dts]

for arg in dedup([uop.arg for uop in uops if uop.uop is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
@@ -379,7 +379,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
return super().render_kernel(function_name, kernel, bufs, uops, prefix)

def get_kernel_modifier(self, uops:UOpGraph) -> str:
requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop is UOps.SPECIAL and u.arg[1][0] == "l")
requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.op is UOps.SPECIAL and u.arg[1][0] == "l")
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

@@ -78,7 +78,7 @@ class LLVMRenderer(Renderer):
module = ir.Module(name=__file__)

# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}

# create llvm function
@@ -101,7 +101,7 @@ class LLVMRenderer(Renderer):
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))

for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
uop,dtype,vin,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.STORE:
element = cast(bb, lvars[vin[2]], vin[2].dtype, vin[0].dtype)
if len(vin) > 3:
@@ -115,7 +115,7 @@ class LLVMRenderer(Renderer):
lvars[vin[0]].add_incoming(idx_p1, bb[-1].block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].src[1]]), loop_entry_bb, bb[-1].block)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE:
@@ -147,7 +147,7 @@ class LLVMRenderer(Renderer):
lvars[u] = lvars[vin[1]]
# PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
backward = vin[0]
while backward.uop is UOps.PHI: backward = backward.vin[0]
while backward.op is UOps.PHI: backward = backward.src[0]
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPNE) else vin[0].dtype)

@@ -191,7 +191,7 @@ class PythonRenderer(Renderer):
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores

def render(self, name:str, uops:UOpGraph) -> str:
lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
lops = [(u.op, u.dtype, [uops.uops.index(v) for v in u.src], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()

class PythonCompiler(Compiler):