From 2b31823ef9f178a2e9716a6a992fdbb7fdb8b792 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 19 Feb 2026 16:18:59 -0500 Subject: [PATCH] update test_gpudims to prove bijectivity (#14895) * update test_gpudims to prove bijectivity * one more --- test/null/test_gpudims.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/null/test_gpudims.py b/test/null/test_gpudims.py index 50492e0592..e84b5ec752 100644 --- a/test/null/test_gpudims.py +++ b/test/null/test_gpudims.py @@ -19,16 +19,22 @@ class TestGroupedDims(unittest.TestCase): self._verify_indices_z3(idxs, dims) def _verify_indices_z3(self, idxs, dims): - """Use z3 to prove 0 <= flat < total for the returned indices. - NOTE: no injectivity check — z3 is too slow on nested div/mod expressions (e.g. reverse+split takes ~4s).""" + """Use z3 to prove bijectivity: bounds (0 <= flat < total) + injectivity (different inputs => different flat).""" total = math.prod(dims) + specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg) + # build flat index and primed flat (same expression with renamed SPECIALs) flat = UOp.const(dtypes.index, 0) for i, idx in enumerate(idxs): flat = flat + idx * int(math.prod(dims[i+1:])) + flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials}) solver = z3.Solver() - [z3_flat] = uops_to_z3(solver, flat) + [z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p) + # bounds self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}") self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}") + # injectivity: flat == flat' but inputs differ => unsat + inputs_differ = z3.Or(*[z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials]) + self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}") def test_grouped_dims(self): # no-op @@ -45,6 +51,7 @@ class TestGroupedDims(unittest.TestCase): self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16]) self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32]) self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256]) + self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14]) # prefer group_dim strategy when possible self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])