From b9744ab62b2e2a84a0bcfce3c41cd24e8bca7181 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 19 Feb 2026 18:18:44 -0500 Subject: [PATCH] one more test_gpudims test (#14898) failure from the bad simplification attempt --- test/null/test_gpudims.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/null/test_gpudims.py b/test/null/test_gpudims.py index 8324e29a8c..b18ff6c1ac 100644 --- a/test/null/test_gpudims.py +++ b/test/null/test_gpudims.py @@ -66,6 +66,8 @@ class TestGroupedDims(unittest.TestCase): self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False) # test when the only divisor is the square root of dim self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False) + # 2 -> 3 + self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False) # collapse on onto the left most axis self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5]) @@ -85,8 +87,11 @@ class TestGroupedDims(unittest.TestCase): with self.assertRaises(RuntimeError): get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16)) - def test_split_2d_to_3d(self): - self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False) + def test_grouped_direct_dims_are_special(self): + # when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod) + idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False) + assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}" + assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}" def test_max_sizes_none(self): self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])