diff --git a/test/null/test_gpudims.py b/test/null/test_gpudims.py index 8324e29a8c..b18ff6c1ac 100644 --- a/test/null/test_gpudims.py +++ b/test/null/test_gpudims.py @@ -66,6 +66,8 @@ class TestGroupedDims(unittest.TestCase): self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False) # test when the only divisor is the square root of dim self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False) + # 2 -> 3 + self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False) # collapse on onto the left most axis self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5]) @@ -85,8 +87,11 @@ class TestGroupedDims(unittest.TestCase): with self.assertRaises(RuntimeError): get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16)) - def test_split_2d_to_3d(self): - self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False) + def test_grouped_direct_dims_are_special(self): + # when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod) + idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False) + assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}" + assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}" def test_max_sizes_none(self): self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])