update test_gpudims to prove bijectivity (#14895)

* update test_gpudims to prove bijectivity

* one more
This commit is contained in:
chenyu
2026-02-19 16:18:59 -05:00
committed by GitHub
parent 19ce7a3f7f
commit 2b31823ef9

View File

@@ -19,16 +19,22 @@ class TestGroupedDims(unittest.TestCase):
self._verify_indices_z3(idxs, dims)
def _verify_indices_z3(self, idxs, dims):
"""Use z3 to prove 0 <= flat < total for the returned indices.
NOTE: no injectivity check — z3 is too slow on nested div/mod expressions (e.g. reverse+split takes ~4s)."""
"""Use z3 to prove bijectivity: bounds (0 <= flat < total) + injectivity (different inputs => different flat)."""
total = math.prod(dims)
specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg)
# build flat index and primed flat (same expression with renamed SPECIALs)
flat = UOp.const(dtypes.index, 0)
for i, idx in enumerate(idxs):
flat = flat + idx * int(math.prod(dims[i+1:]))
flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials})
solver = z3.Solver()
[z3_flat] = uops_to_z3(solver, flat)
[z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p)
# bounds
self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}")
self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}")
# injectivity: flat == flat' but inputs differ => unsat
inputs_differ = z3.Or(*[z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials])
self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}")
def test_grouped_dims(self):
# no-op
@@ -45,6 +51,7 @@ class TestGroupedDims(unittest.TestCase):
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14])
# prefer group_dim strategy when possible
self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])