From b563cd52eddaede18d01405ffba470b0c215bbcc Mon Sep 17 00:00:00 2001
From: Francis Lam
Date: Sun, 23 Jun 2024 15:53:15 -0700
Subject: [PATCH] linearizer: change globals to merge into left axis/gridDims.x first (#5033)

* linearizer: change order of collapse to be left-most

also fixes Variable max size to be correct and adds docs for the off
parameter

* fix multiple global dim oversizes

* add passing variable test and reorganize tests

* use assert RuntimeError for failing test
---
 test/test_linearizer.py        | 36 ++++++++++++++++++++++++++++--------
 tinygrad/codegen/linearizer.py | 31 +++++++++++++++----------------
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 0a8f34b817..76431ed40c 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -608,26 +608,46 @@ class TestLinearizer(unittest.TestCase):
         assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
         assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
 
+    # pad sizes with ones if necessary
     _assert_grouped_dims("gidx", (2,), (16,16,16,), False, [2,1,1])
     _assert_grouped_dims("gidx", (2,3), (16,16,16,), False, [2,3,1])
+
+    # check reverse dims
     _assert_grouped_dims("gidx", (2,3), (16,16,16,), True, [3,2,1])
     _assert_grouped_dims("gidx", (2,3,4,), (16,16,16,), False, [2,3,4])
-    _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
-    # _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2]) # this is the new linearizer way
-    # _assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2]) # this is the new linearizer way
-    _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,4,6]) # this is the old linearizer way - TODO: remove
+
+    # test splitting globals
     _assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), False, [16,12,4])
     _assert_grouped_dims("gidx", (64,3,4,), (16,4,16,), False, [16,4,12])
     _assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), True, [12,16,4])
+    _assert_grouped_dims("gidx", (128,3,4,), (16,4,256,), False, [16,4,24])
 
-    with self.assertRaises(AssertionError): # dim too large and not factorable
+    # collapse onto the left-most axis
+    _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
+    _assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2])
+    _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (32,16,16,), True, [20,3,Variable("start_pos",1,2)])
+
+    # collapse onto the left-most available axis (the left-most axis is too small)
+    _assert_grouped_dims("gidx", (2,3,4,5,), (4,16,16,), False, [2,12,5])
+    _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2])
+
+    # TODO: support sint collapse
+    with self.assertRaises(RuntimeError):
+      _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (16,16,16,), False, [Variable("start_pos",1,2)*3,4,5])
+
+    # dim too large and not factorable
+    with self.assertRaises(AssertionError):
       get_grouped_dims("gidx", 0, (23,), (16,16,16,), False,)
+    with self.assertRaises(AssertionError):
+      get_grouped_dims("gidx", 0, (128,3,4), (16,4,23,), False,)
 
-    with self.assertRaises(AssertionError): # too large for sizes
+    # too large for sizes
+    with self.assertRaises(AssertionError):
       get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16,), False,)
 
-    with self.assertRaises(AssertionError): # variable too large
-      get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 17),3,4), (16,16,16,), False,)
+    # variable too large
+    with self.assertRaises(AssertionError):
+      get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,)
 
   def test_sum_collapse(self):
     t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 7797703d84..5e4c8bedd9 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -22,6 +22,7 @@ def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optio
 
   Keyword arguments:
   prefix -- the prefix to use for the size Variable names.
+  off -- the starting index for the size Variable names.
   dims -- the global or local dims of the full shape.
   max_sizes -- the maximum values for each size in (x, y, z) order.
   reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).
@@ -34,41 +35,39 @@ def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optio
 
   # initialize the map of dims to size with a single dim in each size axis
   # TODO: support sint properly
-  size_dims:List[List[Tuple[int, sint]]] = [[(dim_idx, dim)] for dim_idx, dim in enumerate(dims)]
+  size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]
 
   # reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
   # TODO: remove reverse_dims, the mapping of dims to size for globals should be cosearched with memory layouts for optimal peformance
   if reverse_dims: size_dims = size_dims[::-1]
 
   # ensure that the initial dims initially fit the valid size axes
-  for size_idx, max_sz in [(i, sz) for i, sz in enumerate(max_sizes[:len(size_dims)]) if size_dims[i][0][1] > sz]:
+  for size_idx in range(min(len(max_sizes), len(size_dims))):
     # if the initial dim is too large, split the dim to separate size axes, if possible
-    dim_idx, dim_max = size_dims[size_idx][0]
-    assert isinstance(dim_max, int), "variable shape too large for size"
-    for factor in range(2, int(dim_max**0.5)+1):
-      if dim_max % factor == 0 and dim_max // factor <= max_sz:
-        size_dims = size_dims[:size_idx] + [[(dim_idx, dim_max//factor)], [(dim_idx, factor)]] + size_dims[size_idx+1:]
+    dim_idx, dim, dim_max = size_dims[size_idx][0]
+    if dim_max <= (max_sz:=max_sizes[size_idx]): continue
+    assert isinstance(dim, int), "variable shape too large for size"
+    for factor in range(2, int(dim**0.5)+1):
+      if dim % factor == 0 and dim // factor <= max_sz:
+        size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
         break
-    assert size_dims[size_idx][0][1] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim_max} > {max_sz}"
+    assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
 
   # compress the extra dims, collapsing them onto the left-most valid size axis
-  # for run_process_replay, collapse onto the right-most dim to compare the outputs.  TODO: remove
-  if reverse_dims: size_dims, max_sizes = size_dims[::-1], max_sizes[::-1]
   cur_size_idx = 0
   while len(size_dims) > len(max_sizes):
-    if prod([dim_max for (_, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][1] < max_sizes[cur_size_idx]:
+    if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
       size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
     elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
     else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
-  if reverse_dims: size_dims, max_sizes = size_dims[::-1], max_sizes[::-1]
 
   # construct the final dim idx variables from the the portions of the size variables
-  sizes, idxs = [prod([dim_max for (_, dim_max) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
+  sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
   size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
   for size_idx, size_var in enumerate(size_vars):
-    for dim_idx, dim_max in size_dims[size_idx]:
-      idxs[dim_idx] += (size_var % dim_max) * (idxs[dim_idx].max+1)
-      size_var //= dim_max
+    for dim_idx, dim, _ in size_dims[size_idx]:
+      idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
+      size_var //= dim
 
   # pad the final sizes array to the proper length if necessary
   return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))
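
A minimal usage sketch of the new left-most collapse order, replaying two of the expectations asserted in test_linearizer.py above. It is illustrative, not part of the patch, and assumes a tinygrad checkout at this commit with get_grouped_dims importable from tinygrad.codegen.linearizer (the path shown in the diff).

# illustrative sketch only; assumes tinygrad at this commit, import path taken from the diff
from tinygrad.codegen.linearizer import get_grouped_dims

# four global dims but only three launch axes: the extra dim now merges into the
# left-most axis first, so (2,3,4,5) under limits (16,16,16) groups 2*3 into .x -> [6,4,5]
idxs, loop_idxs, sizes = get_grouped_dims("gidx", 0, (2, 3, 4, 5), (16, 16, 16), False)
assert sizes == [6, 4, 5]

# with reverse_dims=True the dims are mapped right-to-left before collapsing,
# giving [5,12,2] for the same shapes (the expectation that was previously commented out)
idxs, loop_idxs, sizes = get_grouped_dims("gidx", 0, (2, 3, 4, 5), (16, 16, 16), True)
assert sizes == [5, 12, 2]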