make some grouped_dim test work (#5415)

next need to support max size per dim, splitting and correct way to do reverse or arbitrary permute global dims
This commit is contained in:
chenyu
2024-07-12 14:22:50 -04:00
committed by GitHub
parent b7cc75a9df
commit 76125c07be
2 changed files with 37 additions and 34 deletions

View File

@@ -617,59 +617,60 @@ class TestLinearizer(unittest.TestCase):
end_range = [i for i, x in enumerate(k.uops) if x.op is UOps.ENDRANGE][0]
assert end_range < k.uops.uops.index(u)
@unittest.skip("this changed. TODO: bring test back")
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
idxs, loop_idxs, sizes = get_grouped_dims(prefix, 0, dims, max_sizes, reverse_dims)
# TODO: get_grouped_dims max_sizes should be 3 int tuple
# TODO: fix reverse_dims
max_sizes = 3
idxs, loop_idxs = get_grouped_dims(prefix, 0, dims, max_sizes)
sizes = [x.arg[2] for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
for i in range(len(dims)):
assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
for i in range(len(loop_idxs)):
assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
# for i in range(len(dims)):
# assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
# for i in range(len(loop_idxs)):
# assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
# assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
# pad sizes with ones if necessary
_assert_grouped_dims("gidx", (2,), (16,16,16,), False, [2,1,1])
_assert_grouped_dims("gidx", (2,3), (16,16,16,), False, [2,3,1])
# no-op
_assert_grouped_dims("gidx", (2,), (16,16,16,), False, [2])
_assert_grouped_dims("gidx", (2,3), (16,16,16,), False, [2,3])
# check reverse dims
_assert_grouped_dims("gidx", (2,3), (16,16,16,), True, [3,2,1])
# _assert_grouped_dims("gidx", (2,3), (16,16,16,), True, [3,2])
_assert_grouped_dims("gidx", (2,3,4,), (16,16,16,), False, [2,3,4])
# test splitting globals
_assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), False, [16,12,4])
_assert_grouped_dims("gidx", (64,3,4,), (16,4,16,), False, [16,4,12])
_assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), True, [12,16,4])
_assert_grouped_dims("gidx", (128,3,4,), (16,4,256,), False, [16,4,24])
# _assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), False, [16,12,4])
# _assert_grouped_dims("gidx", (64,3,4,), (16,4,16,), False, [16,4,12])
# _assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), True, [12,16,4])
# _assert_grouped_dims("gidx", (128,3,4,), (16,4,256,), False, [16,4,24])
# collapse on onto the left most axis
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
_assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2])
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (32,16,16,), True, [20,3,Variable("start_pos",1,2)])
# _assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2])
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (32,16,16,), True, [20,3,Variable("start_pos",1,2)])
# collapse on left-most available axis (the left most is too small)
_assert_grouped_dims("gidx", (2,3,4,5,), (4,16,16,), False, [2,12,5])
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2])
# _assert_grouped_dims("gidx", (2,3,4,5,), (4,16,16,), False, [2,12,5])
# _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2])
# TODO: support sint collapse
with self.assertRaises(RuntimeError):
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (16,16,16,), False, [Variable("start_pos",1,2)*3,4,5])
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (16,16,16,), False, [Variable("start_pos",1,2)*3,4,5])
# dim too large and not factorable
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (23,), (16,16,16,), False,)
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (128,3,4), (16,4,23,), False,)
# # dim too large and not factorable
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", 0, (23,), (16,16,16,), False,)
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", 0, (128,3,4), (16,4,23,), False,)
# too large for sizes
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16,), False,)
# # too large for sizes
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16,), False,)
# variable too large
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,)
# # variable too large
# with self.assertRaises(AssertionError):
# get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,)
def test_div_collapse(self):
def helper(t, msg, max_ops=0):