linearizer: change globals to merge into left axis/gridDims.x first (#5033)

* linearizer: change order of collapse to be left-most

also fixes Variable max size to be correct and add docs for the off
parameter

* fix multiple global dim oversizes

* add passing variable test and reorganize tests

* use assert RuntimeError for failing test
This commit is contained in:
Francis Lam
2024-06-23 15:53:15 -07:00
committed by GitHub
parent 69f116a7e1
commit b563cd52ed
2 changed files with 43 additions and 24 deletions

View File

@@ -608,26 +608,46 @@ class TestLinearizer(unittest.TestCase):
assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
# pad sizes with ones if necessary
_assert_grouped_dims("gidx", (2,), (16,16,16,), False, [2,1,1])
_assert_grouped_dims("gidx", (2,3), (16,16,16,), False, [2,3,1])
# check reverse dims
_assert_grouped_dims("gidx", (2,3), (16,16,16,), True, [3,2,1])
_assert_grouped_dims("gidx", (2,3,4,), (16,16,16,), False, [2,3,4])
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
# _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2]) # this is the new linearizer way
# _assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2]) # this is the new linearizer way
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,4,6]) # this is the old linearizer way - TODO: remove
# test splitting globals
_assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), False, [16,12,4])
_assert_grouped_dims("gidx", (64,3,4,), (16,4,16,), False, [16,4,12])
_assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), True, [12,16,4])
_assert_grouped_dims("gidx", (128,3,4,), (16,4,256,), False, [16,4,24])
with self.assertRaises(AssertionError): # dim too large and not factorable
# collapse onto the left-most axis
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
_assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2])
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (32,16,16,), True, [20,3,Variable("start_pos",1,2)])
# collapse on left-most available axis (the left-most one is too small)
_assert_grouped_dims("gidx", (2,3,4,5,), (4,16,16,), False, [2,12,5])
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2])
# TODO: support sint collapse
with self.assertRaises(RuntimeError):
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (16,16,16,), False, [Variable("start_pos",1,2)*3,4,5])
# dim too large and not factorable
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (23,), (16,16,16,), False,)
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (128,3,4), (16,4,23,), False,)
with self.assertRaises(AssertionError): # too large for sizes
# too large for sizes
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16,), False,)
with self.assertRaises(AssertionError): # variable too large
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 17),3,4), (16,16,16,), False,)
# variable too large
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,)
def test_sum_collapse(self):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()

View File

@@ -22,6 +22,7 @@ def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optio
Keyword arguments:
prefix -- the prefix to use for the size Variable names.
off -- the starting index for the size Variable names.
dims -- the global or local dims of the full shape.
max_sizes -- the maximum values for each size in (x, y, z) order.
reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).
@@ -34,41 +35,39 @@ def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optio
# initialize the map of dims to size with a single dim in each size axis
# TODO: support sint properly
size_dims:List[List[Tuple[int, sint]]] = [[(dim_idx, dim)] for dim_idx, dim in enumerate(dims)]
size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]
# reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
# TODO: remove reverse_dims, the mapping of dims to size for globals should be co-searched with memory layouts for optimal performance
if reverse_dims: size_dims = size_dims[::-1]
# ensure that the initial dims initially fit the valid size axes
for size_idx, max_sz in [(i, sz) for i, sz in enumerate(max_sizes[:len(size_dims)]) if size_dims[i][0][1] > sz]:
for size_idx in range(min(len(max_sizes), len(size_dims))):
# if the initial dim is too large, split the dim to separate size axes, if possible
dim_idx, dim_max = size_dims[size_idx][0]
assert isinstance(dim_max, int), "variable shape too large for size"
for factor in range(2, int(dim_max**0.5)+1):
if dim_max % factor == 0 and dim_max // factor <= max_sz:
size_dims = size_dims[:size_idx] + [[(dim_idx, dim_max//factor)], [(dim_idx, factor)]] + size_dims[size_idx+1:]
dim_idx, dim, dim_max = size_dims[size_idx][0]
if dim_max <= (max_sz:=max_sizes[size_idx]): continue
assert isinstance(dim, int), "variable shape too large for size"
for factor in range(2, int(dim**0.5)+1):
if dim % factor == 0 and dim // factor <= max_sz:
size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
break
assert size_dims[size_idx][0][1] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim_max} > {max_sz}"
assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
# compress the extra dims, collapsing them onto the left-most valid size axis
# for run_process_replay, collapse onto the right-most dim to compare the outputs. TODO: remove
if reverse_dims: size_dims, max_sizes = size_dims[::-1], max_sizes[::-1]
cur_size_idx = 0
while len(size_dims) > len(max_sizes):
if prod([dim_max for (_, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][1] < max_sizes[cur_size_idx]:
if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
if reverse_dims: size_dims, max_sizes = size_dims[::-1], max_sizes[::-1]
# construct the final dim idx variables from the portions of the size variables
sizes, idxs = [prod([dim_max for (_, dim_max) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
for size_idx, size_var in enumerate(size_vars):
for dim_idx, dim_max in size_dims[size_idx]:
idxs[dim_idx] += (size_var % dim_max) * (idxs[dim_idx].max+1)
size_var //= dim_max
for dim_idx, dim, _ in size_dims[size_idx]:
idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
size_var //= dim
# pad the final sizes array to the proper length if necessary
return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))