UOps.PHI -> UOps.ASSIGN [run_process_replay] (#6383)

This commit is contained in:
George Hotz
2024-09-06 12:38:35 +08:00
committed by GitHub
parent 002303c145
commit 86d34daac9
9 changed files with 29 additions and 29 deletions

View File

@@ -865,7 +865,7 @@ class TestLinearizer(unittest.TestCase):
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
# RANGE -> LOAD -> RANGE -> PHI
# RANGE -> LOAD -> RANGE -> ASSIGN
#assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
def test_three_nested_range(self):
@@ -874,7 +874,7 @@ class TestLinearizer(unittest.TestCase):
lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
# RANGE -> RANGE -> LOAD -> RANGE -> PHI
# RANGE -> RANGE -> LOAD -> RANGE -> ASSIGN
# NOTE: nothing should toposort between the first two ranges
#assert ranges[0]+1 == ranges[1]
#assert any(x.op is UOps.LOAD for x in lin.uops[ranges[1]:ranges[2]])
@@ -884,7 +884,7 @@ class TestLinearizer(unittest.TestCase):
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum()
lin = helper_linearizer_opt(out, wanna_output=[24])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
# RANGE -> ALU -> RANGE -> ALU + LOAD -> PHI
# RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN
assert any(x.op is UOps.ALU for x in lin.uops[ranges[0]:ranges[1]])
assert not any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
assert any(x.op in {UOps.ALU, UOps.LOAD} for x in lin.uops[ranges[1]:])
@@ -895,7 +895,7 @@ class TestLinearizer(unittest.TestCase):
out = (a + b[0]).sum() + b[0]
lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
# LOAD -> RANGE -> LOAD -> PHI
# LOAD -> RANGE -> LOAD -> ASSIGN
assert lin.uops[ranges[0]-2].op is UOps.LOAD
def test_range_outer_op_before_phi_nested_range(self):
@@ -906,11 +906,11 @@ class TestLinearizer(unittest.TestCase):
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
#if getenv("PTX"):
# LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> PHI
# LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> ASSIGN
# assert lin.uops[ranges[0]-2].op is UOps.LOAD
# assert ranges[1] == ranges[0]+6
# assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
# LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI
# LOAD -> RANGE -> LOAD -> ALU -> RANGE -> ASSIGN
#else:
# assert lin.uops[ranges[0]-2].op is UOps.LOAD
# assert ranges[1] == ranges[0]+3
@@ -920,7 +920,7 @@ class TestLinearizer(unittest.TestCase):
a = Tensor.randn(4, 1).realize()
out = a.sum() * a.sum()
lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0]
# RANGE -> LOAD -> PHI -> ALU
# RANGE -> LOAD -> ASSIGN -> ALU
end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
assert lin.uops[end+1].op is UOps.ALU
@@ -928,7 +928,7 @@ class TestLinearizer(unittest.TestCase):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum()
lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0]
# RANGE -> LOAD -> PHI -> ALU
# RANGE -> LOAD -> ASSIGN -> ALU
end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE)
assert lin.uops[end+1].op is UOps.ALU
@@ -1128,7 +1128,7 @@ class TestLinearizer(unittest.TestCase):
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
for u in k.uops:
if u.op is UOps.WMMA:
assert u.src[-1].src[0].op != UOps.PHI
assert u.src[-1].src[0].op != UOps.ASSIGN
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
@@ -1140,12 +1140,12 @@ class TestLinearizer(unittest.TestCase):
for u in k.uops:
if u.op is UOps.WMMA:
#assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
assert u.src[-1].src[0].op != UOps.PHI
assert u.src[-1].src[0].op != UOps.ASSIGN
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
def test_tensor_cores_unroll_casted_phi_with_children(self):
# all PHI children are outside the loop
# all ASSIGN children are outside the loop
tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
r = x.matmul(y, acc_dtype=tc.dtype_out).relu()
@@ -1153,19 +1153,19 @@ class TestLinearizer(unittest.TestCase):
for u in k.uops:
if u.op is UOps.WMMA:
#assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
assert u.src[-1].src[0].op != UOps.PHI
assert u.src[-1].src[0].op != UOps.ASSIGN
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_simple_unroll_no_between_phi_dependencies(self):
x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
r = (x@y).relu()
k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1]
# the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x PHI -> ENDRANGE
# the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x ASSIGN -> ENDRANGE
for u in k.uops:
if u.op is UOps.PHI:
if u.op is UOps.ASSIGN:
assert u.src[1].op is UOps.ALU
# children of PHI are placed after ENDRANGE
if any(x.op is UOps.PHI for x in u.src):
# children of ASSIGN are placed after ENDRANGE
if any(x.op is UOps.ASSIGN for x in u.src):
end_range = [i for i, x in enumerate(k.uops) if x.op is UOps.ENDRANGE][0]
assert end_range < k.uops.index(u)
@@ -1292,7 +1292,7 @@ class TestLinearizer(unittest.TestCase):
if if_op:=next((u for u in uops if u.op is UOps.IF), None):
uops = uops[:uops.index(if_op)]
assert len(set([u.op for u in uops if u.op in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both"
assert len([u for u in uops if u.op is UOps.PHI]) == 0, "PHI should have been simplified"
assert len([u for u in uops if u.op is UOps.ASSIGN]) == 0, "ASSIGN should have been simplified"
# TODO: once uops track min/max this will be fixed
#assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"