one more test_gpudims test (#14898)

failure from the bad simplification attempt
This commit is contained in:
chenyu
2026-02-19 18:18:44 -05:00
committed by GitHub
parent 9d6cf00be2
commit b9744ab62b

View File

@@ -66,6 +66,8 @@ class TestGroupedDims(unittest.TestCase):
self._check_grouped_dims("gidx", (65536,2), (65535,65535,65535), False, [32768,4], False)
# test when the only divisor is the square root of dim
self._check_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
# 2 -> 3
self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
# collapse on onto the left most axis
self._check_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
@@ -85,8 +87,11 @@ class TestGroupedDims(unittest.TestCase):
with self.assertRaises(RuntimeError):
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
def test_split_2d_to_3d(self):
self._check_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
def test_grouped_direct_dims_are_special(self):
# when (2,3) are merged into 6, the unmerged dims (4,5) should map directly to SPECIAL ops (no div/mod)
idxs = get_grouped_dims("gidx", (2,3,4,5), (16,16,16), False)
assert idxs[2].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[2].op}"
assert idxs[3].op is Ops.SPECIAL, f"expected SPECIAL for direct-mapped dim, got {idxs[3].op}"
def test_max_sizes_none(self):
self._check_grouped_dims("gidx", (2,3,4), None, False, [2,3,4])