mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update test_gpudims to prove bijectivity (#14895)
* update test_gpudims to prove bijectivity * one more
This commit is contained in:
@@ -19,16 +19,22 @@ class TestGroupedDims(unittest.TestCase):
|
||||
self._verify_indices_z3(idxs, dims)
|
||||
|
||||
def _verify_indices_z3(self, idxs, dims):
|
||||
"""Use z3 to prove 0 <= flat < total for the returned indices.
|
||||
NOTE: no injectivity check — z3 is too slow on nested div/mod expressions (e.g. reverse+split takes ~4s)."""
|
||||
"""Use z3 to prove bijectivity: bounds (0 <= flat < total) + injectivity (different inputs => different flat)."""
|
||||
total = math.prod(dims)
|
||||
specials = sorted(dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs])), key=lambda u: u.arg)
|
||||
# build flat index and primed flat (same expression with renamed SPECIALs)
|
||||
flat = UOp.const(dtypes.index, 0)
|
||||
for i, idx in enumerate(idxs):
|
||||
flat = flat + idx * int(math.prod(dims[i+1:]))
|
||||
flat_p = flat.substitute({s: UOp(Ops.SPECIAL, s.dtype, s.src, s.arg+"_p") for s in specials})
|
||||
solver = z3.Solver()
|
||||
[z3_flat] = uops_to_z3(solver, flat)
|
||||
[z3_flat, z3_flat_p] = uops_to_z3(solver, flat, flat_p)
|
||||
# bounds
|
||||
self.assertEqual(solver.check(z3_flat < 0), z3.unsat, f"flat can be negative: {dims=}")
|
||||
self.assertEqual(solver.check(z3_flat >= total), z3.unsat, f"flat can be >= {total}: {dims=}")
|
||||
# injectivity: flat == flat' but inputs differ => unsat
|
||||
inputs_differ = z3.Or(*[z3.Int(s.arg) != z3.Int(s.arg+"_p") for s in specials])
|
||||
self.assertEqual(solver.check(z3.And(z3_flat == z3_flat_p, inputs_differ)), z3.unsat, f"not injective: {dims=}")
|
||||
|
||||
def test_grouped_dims(self):
|
||||
# no-op
|
||||
@@ -45,6 +51,7 @@ class TestGroupedDims(unittest.TestCase):
|
||||
self._check_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
|
||||
self._check_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
|
||||
self._check_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
|
||||
self._check_grouped_dims("gidx", (5,12,7), (8,4,16), False, [10,3,14])
|
||||
|
||||
# prefer group_dim strategy when possible
|
||||
self._check_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
|
||||
|
||||
Reference in New Issue
Block a user