mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
linearizer: change globals to merge into left axis/gridDims.x first (#5033)
* linearizer: change order of collapse to be left-most also fixes Variable max size to be correct and add docs for the off parameter * fix multiple global dim oversizes * add passing variable test and reorganize tests * use assert RuntimeError for failing test
This commit is contained in:
@@ -608,26 +608,46 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
|
||||
assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
|
||||
|
||||
# pad sizes with ones if necessary
|
||||
_assert_grouped_dims("gidx", (2,), (16,16,16,), False, [2,1,1])
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16,), False, [2,3,1])
|
||||
|
||||
# check reverse dims
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16,), True, [3,2,1])
|
||||
_assert_grouped_dims("gidx", (2,3,4,), (16,16,16,), False, [2,3,4])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
|
||||
# _assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2]) # this is the new linearizer way
|
||||
# _assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2]) # this is the new linearizer way
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,4,6]) # this is the old linearizer way - TODO: remove
|
||||
|
||||
# test splitting globals
|
||||
_assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), False, [16,12,4])
|
||||
_assert_grouped_dims("gidx", (64,3,4,), (16,4,16,), False, [16,4,12])
|
||||
_assert_grouped_dims("gidx", (64,3,4,), (16,16,16,), True, [12,16,4])
|
||||
_assert_grouped_dims("gidx", (128,3,4,), (16,4,256,), False, [16,4,24])
|
||||
|
||||
with self.assertRaises(AssertionError): # dim too large and not factorable
|
||||
# collapse on onto the left most axis
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), False, [6,4,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (32,16,16,), True, [20,3,2])
|
||||
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (32,16,16,), True, [20,3,Variable("start_pos",1,2)])
|
||||
|
||||
# collapse on left-most available axis (the left most is too small)
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (4,16,16,), False, [2,12,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5,), (16,16,16,), True, [5,12,2])
|
||||
|
||||
# TODO: support sint collapse
|
||||
with self.assertRaises(RuntimeError):
|
||||
_assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5,), (16,16,16,), False, [Variable("start_pos",1,2)*3,4,5])
|
||||
|
||||
# dim too large and not factorable
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (23,), (16,16,16,), False,)
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (128,3,4), (16,4,23,), False,)
|
||||
|
||||
with self.assertRaises(AssertionError): # too large for sizes
|
||||
# too large for sizes
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (2,3,4,5,6), (16,16,16,), False,)
|
||||
|
||||
with self.assertRaises(AssertionError): # variable too large
|
||||
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 17),3,4), (16,16,16,), False,)
|
||||
# variable too large
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,)
|
||||
|
||||
def test_sum_collapse(self):
|
||||
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
|
||||
|
||||
Reference in New Issue
Block a user