mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
linearizer: change globals to merge into left axis/gridDims.x first (#5033)
* linearizer: change order of collapse to be left-most; also fixes Variable max size to be correct and adds docs for the off parameter
* fix multiple global dim oversizes
* add passing variable test and reorganize tests
* use assert RuntimeError for failing test
This commit is contained in:
@@ -608,26 +608,46 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
|
||||
assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
|
||||
|
||||
# pad sizes with ones if necessary
|
||||
_assert_grouped_dims("gidx", (2,), (16,16,16,), False, [2,1,1])
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16,), False, [2,3,1])
|
||||
|
||||
# check reverse dims
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16,), True, [3,2,1])
|
||||
_assert_grouped_dims("gidx", (2,3,4,), (16,16,16,), False, [2,3,4])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
|
||||
# _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2]) # this is the new linearizer way
|
||||
# _assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2]) # this is the new linearizer way
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,4,6]) # this is the old linearizer way - TODO: remove
|
||||
|
||||
# test splitting globals
|
||||
_assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), False, [16,12,4])
|
||||
_assert_grouped_dims("gidx", (64,3,4,), (16,4,16,), False, [16,4,12])
|
||||
_assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), True, [12,16,4])
|
||||
_assert_grouped_dims("gidx", (128,3,4,), (16,4,256,), False, [16,4,24])
|
||||
|
||||
with self.assertRaises(AssertionError): # dim too large and not factorable
|
||||
# collapse onto the left-most axis
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2])
|
||||
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (32,16,16,), True, [20,3,Variable("start_pos",1,2)])
|
||||
|
||||
# collapse on left-most available axis (the left most is too small)
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (4,16,16,), False, [2,12,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2])
|
||||
|
||||
# TODO: support sint collapse
|
||||
with self.assertRaises(RuntimeError):
|
||||
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (16,16,16,), False, [Variable("start_pos",1,2)*3,4,5])
|
||||
|
||||
# dim too large and not factorable
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (23,), (16,16,16,), False,)
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (128,3,4), (16,4,23,), False,)
|
||||
|
||||
with self.assertRaises(AssertionError): # too large for sizes
|
||||
# too large for sizes
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16,), False,)
|
||||
|
||||
with self.assertRaises(AssertionError): # variable too large
|
||||
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 17),3,4), (16,16,16,), False,)
|
||||
# variable too large
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,)
|
||||
|
||||
def test_sum_collapse(self):
|
||||
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
|
||||
|
||||
@@ -22,6 +22,7 @@ def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optio
|
||||
|
||||
Keyword arguments:
|
||||
prefix -- the prefix to use for the size Variable names.
|
||||
off -- the starting index for the size Variable names.
|
||||
dims -- the global or local dims of the full shape.
|
||||
max_sizes -- the maximum values for each size in (x, y, z) order.
|
||||
reverse_dims -- reverse the order of the dims as they are mapped into size, i.e. if True, the right dim will go to the left size (.x).
|
||||
@@ -34,41 +35,39 @@ def get_grouped_dims(prefix:str, off:int, dims:Tuple[sint, ...], max_sizes:Optio
|
||||
|
||||
# initialize the map of dims to size with a single dim in each size axis
|
||||
# TODO: support sint properly
|
||||
size_dims:List[List[Tuple[int, sint]]] = [[(dim_idx, dim)] for dim_idx, dim in enumerate(dims)]
|
||||
size_dims:List[List[Tuple[int, sint, sint]]] = [[(dim_idx, dim, dim if isinstance(dim, int) else dim.max+1)] for dim_idx, dim in enumerate(dims)]
|
||||
|
||||
# reverse the order of the dims to size map, if desired (currently for globals where smallest stride is on the right)
|
||||
# TODO: remove reverse_dims, the mapping of dims to size for globals should be cosearched with memory layouts for optimal performance
|
||||
if reverse_dims: size_dims = size_dims[::-1]
|
||||
|
||||
# ensure that the initial dims fit the valid size axes
|
||||
for size_idx, max_sz in [(i, sz) for i, sz in enumerate(max_sizes[:len(size_dims)]) if size_dims[i][0][1] > sz]:
|
||||
for size_idx in range(min(len(max_sizes), len(size_dims))):
|
||||
# if the initial dim is too large, split the dim to separate size axes, if possible
|
||||
dim_idx, dim_max = size_dims[size_idx][0]
|
||||
assert isinstance(dim_max, int), "variable shape too large for size"
|
||||
for factor in range(2, int(dim_max**0.5)+1):
|
||||
if dim_max % factor == 0 and dim_max // factor <= max_sz:
|
||||
size_dims = size_dims[:size_idx] + [[(dim_idx, dim_max//factor)], [(dim_idx, factor)]] + size_dims[size_idx+1:]
|
||||
dim_idx, dim, dim_max = size_dims[size_idx][0]
|
||||
if dim_max <= (max_sz:=max_sizes[size_idx]): continue
|
||||
assert isinstance(dim, int), "variable shape too large for size"
|
||||
for factor in range(2, int(dim**0.5)+1):
|
||||
if dim % factor == 0 and dim // factor <= max_sz:
|
||||
size_dims = size_dims[:size_idx] + [[(dim_idx, dim//factor, dim//factor)], [(dim_idx, factor, factor)]] + size_dims[size_idx+1:]
|
||||
break
|
||||
assert size_dims[size_idx][0][1] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim_max} > {max_sz}"
|
||||
assert size_dims[size_idx][0][2] <= max_sz, f"dim at {size_idx} too large and non-factorable: {dim} > {max_sz}"
|
||||
|
||||
# compress the extra dims, collapsing them onto the left-most valid size axis
|
||||
# for run_process_replay, collapse onto the right-most dim to compare the outputs. TODO: remove
|
||||
if reverse_dims: size_dims, max_sizes = size_dims[::-1], max_sizes[::-1]
|
||||
cur_size_idx = 0
|
||||
while len(size_dims) > len(max_sizes):
|
||||
if prod([dim_max for (_, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][1] < max_sizes[cur_size_idx]:
|
||||
if prod([dim_max for (_, _, dim_max) in size_dims[cur_size_idx]])*size_dims[cur_size_idx+1][0][2] <= max_sizes[cur_size_idx]:
|
||||
size_dims = size_dims[:cur_size_idx] + [size_dims[cur_size_idx] + size_dims[cur_size_idx+1]] + size_dims[cur_size_idx+2:]
|
||||
elif cur_size_idx < len(max_sizes)-1: cur_size_idx += 1
|
||||
else: raise AssertionError(f"cannot fit dims in size: {dims=} {max_sizes=}")
|
||||
if reverse_dims: size_dims, max_sizes = size_dims[::-1], max_sizes[::-1]
|
||||
|
||||
# construct the final dim idx variables from the portions of the size variables
|
||||
sizes, idxs = [prod([dim_max for (_, dim_max) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
|
||||
sizes, idxs = [prod([dim for (_, dim, _) in size_dim]) for size_dim in size_dims], [NumNode(0)] * len(dims)
|
||||
size_vars = loop_idxs = [Variable(f"{prefix}{len(sizes)-1-(i+off) if reverse_dims else i+off}", 0, s-1) for i,s in enumerate(sizes)]
|
||||
for size_idx, size_var in enumerate(size_vars):
|
||||
for dim_idx, dim_max in size_dims[size_idx]:
|
||||
idxs[dim_idx] += (size_var % dim_max) * (idxs[dim_idx].max+1)
|
||||
size_var //= dim_max
|
||||
for dim_idx, dim, _ in size_dims[size_idx]:
|
||||
idxs[dim_idx] += (size_var % dim) * (idxs[dim_idx].max+1)
|
||||
size_var //= dim
|
||||
|
||||
# pad the final sizes array to the proper length if necessary
|
||||
return idxs, [x for x in loop_idxs if not isinstance(x, NumNode)], sizes + [1]*(len(max_sizes)-len(sizes))
|
||||
|
||||
Reference in New Issue
Block a user