UOps.RANGE toposort spec (#4660)

* use iterator

* nested loops and outer loads

* uop after phi
This commit is contained in:
qazal
2024-05-21 04:38:20 +08:00
committed by GitHub
parent 0d9e623d83
commit b33c827aed
2 changed files with 66 additions and 2 deletions

View File

@@ -108,6 +108,69 @@ class TestLinearizer(unittest.TestCase):
self.assertEqual(k.uops.uops[-1].uop, UOps.ENDIF)
self.assertLess(k.uops.uops.index([x for x in k.uops.uops if x.uop is UOps.STORE][-1]), k.uops.uops.index(k.uops.uops[-1]))
def test_two_nested_range(self):
  # Sum over a (2,) tensor broadcast to (2, 3), then inspect where the first two
  # RANGE uops sit in the linearized program.
  # NOTE(review): helper_linearizer_opt presumably also checks the kernel result
  # against wanna_output -- confirm against the helper's definition.
  src = Tensor.randn(2, ).realize()
  out = src.reshape(2, 1).expand(2, 3).sum()
  expected = np.broadcast_to(src.numpy().reshape(2, 1), (2, 3)).sum()
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  range_idxs = [pos for pos, op in enumerate(lin.uops) if op.uop is UOps.RANGE]
  outer, inner = range_idxs[0], range_idxs[1]
  # expected order: RANGE -> LOAD -> RANGE -> PHI
  # exactly one uop separates the two RANGEs ...
  assert inner == outer+2
  # ... and that uop is the LOAD
  assert lin.uops[outer+1].uop is UOps.LOAD
def test_three_nested_range(self):
  # Same idea as the two-range case, but the input is expanded once more to
  # (2, 2, 3) and the positions of the first three RANGE uops are checked.
  src = Tensor.randn(2, ).realize()
  out = src.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
  expected = np.broadcast_to(np.broadcast_to(src.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  range_idxs = [pos for pos, op in enumerate(lin.uops) if op.uop is UOps.RANGE]
  first, second, third = range_idxs[0], range_idxs[1], range_idxs[2]
  # expected order: RANGE -> RANGE -> LOAD -> RANGE -> PHI
  # one uop sits between the second and third RANGE
  assert third == second+2
  # the first two RANGEs are adjacent
  assert second == first+1
  # the uop right after the second RANGE is the LOAD
  assert lin.uops[second+1].uop is UOps.LOAD
def test_two_nested_range_alt_indexing(self):
  # Reduce over a padded tensor and check the spacing between the two RANGE
  # uops and the position of the ENDRANGE that follows the second one.
  # NOTE(review): the expected sum of 24 presumably comes from a 4x3 result
  # filled with 2 (pad value 2) -- confirm against Tensor.pad's signature.
  src = Tensor([2, 2]).realize()
  out = src.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum()
  lin = helper_linearizer_opt(out, wanna_output=[24])[0]
  range_idxs = [pos for pos, op in enumerate(lin.uops) if op.uop is UOps.RANGE]
  outer, inner = range_idxs[0], range_idxs[1]
  # original note: RANGE -> 4x ALU -> RANGE -> 9x ALU + 1x LOAD -> PHI
  # NOTE(review): that note and the +11 ENDRANGE offset below do not obviously
  # line up; the asserted offsets are what this test actually enforces.
  assert inner == outer+5
  assert lin.uops[inner+11].uop is UOps.ENDRANGE
def test_range_outer_op_before_phi(self):
  # b[0] is used both inside the reduce and added to its result; check that a
  # LOAD is emitted ahead of the first RANGE.
  # Tensors are created in the same order as before so random draws match.
  lhs = Tensor.randn(4, 1).realize()
  bias = Tensor.randn(1, 1).realize()
  out = (lhs + bias[0]).sum() + bias[0]
  expected = (lhs.numpy()+bias.numpy()[0]).sum()+bias.numpy()[0]
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  range_idxs = [pos for pos, op in enumerate(lin.uops) if op.uop is UOps.RANGE]
  # original note: LOAD -> RANGE -> LOAD -> PHI
  # The LOAD is asserted two slots before the RANGE, so one other uop sits in
  # between (NOTE(review): which uop that is cannot be told from this test).
  loop_start = range_idxs[0]
  assert lin.uops[loop_start-2].uop is UOps.LOAD
def test_range_outer_op_before_phi_nested_range(self):
  # Nested-range variant of the "outer op before phi" case: a LOAD is expected
  # ahead of the first RANGE, and a LOAD followed by an ALU between the RANGEs.
  # Tensors are created in the same order as before so random draws match.
  src = Tensor.randn(2, ).realize()
  bias = Tensor.randn(1, 1).realize()
  out = (src.reshape(2, 1).expand(2, 3) + bias[0]).sum() + bias[0]
  expected = (np.broadcast_to(src.numpy().reshape(2, 1), (2, 3)) + bias.numpy()[0]).sum() + bias.numpy()[0]
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  range_idxs = [pos for pos, op in enumerate(lin.uops) if op.uop is UOps.RANGE]
  outer = range_idxs[0]
  # original note: LOAD -> RANGE -> LOAD -> ALU -> RANGE -> PHI
  # the LOAD is asserted two slots before the first RANGE
  assert lin.uops[outer-2].uop is UOps.LOAD
  # exactly two uops separate the first and second RANGE
  assert range_idxs[1] == outer+3
  # those two uops are LOAD then ALU, in that order
  between = [op.uop for op in lin.uops[outer+1:outer+3]]
  assert between == [UOps.LOAD, UOps.ALU]
def test_range_outer_op_after_phi(self):
  # Multiply a reduce result by itself and check that an ALU uop comes
  # directly after the last ENDRANGE.
  src = Tensor.randn(4, 1).realize()
  out = src.sum() * src.sum()
  expected = src.numpy().sum()*src.numpy().sum()
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  # original note: RANGE -> LOAD -> PHI -> ALU
  endrange_idxs = [pos for pos, op in enumerate(lin.uops) if op.uop is UOps.ENDRANGE]
  last_end = max(endrange_idxs)
  assert lin.uops[last_end+1].uop is UOps.ALU
def test_range_outer_op_after_phi_nested_range(self):
  # Nested-range variant of the "outer op after phi" case: the same expanded
  # reduce is added to itself, and an ALU uop must come directly after the
  # last ENDRANGE.
  src = Tensor.randn(2, ).realize()
  out = src.reshape(2, 1).expand(2, 3).sum() + src.reshape(2, 1).expand(2, 3).sum()
  expected = (np.broadcast_to(src.numpy().reshape(2, 1), (2, 3))).sum()*2
  lin = helper_linearizer_opt(out, wanna_output=[expected])[0]
  # original note: RANGE -> LOAD -> PHI -> ALU
  endrange_idxs = [pos for pos, op in enumerate(lin.uops) if op.uop is UOps.ENDRANGE]
  last_end = max(endrange_idxs)
  assert lin.uops[last_end+1].uop is UOps.ALU
@unittest.expectedFailure
def test_early_end_local(self):
shape, output_shape = (32,), (1,)