example test use z3 to verify valid simplification (#9684)

This commit is contained in:
chenyu
2025-04-02 01:05:52 -04:00
committed by GitHub
parent bca0c85193
commit c20f112e9f
2 changed files with 29 additions and 1 deletions

View File

@@ -48,7 +48,8 @@ setup(name='tinygrad',
'testing_unit': testing_minimal + [
"tqdm",
"safetensors",
"tabulate" # for sz.py
"z3-solver",
"tabulate", # for sz.py
],
'testing': testing_minimal + [
"pillow",

View File

@@ -138,6 +138,33 @@ class TestValidIdxSimplification(unittest.TestCase):
"(ridx0*1568)",
"((ridx2<1)&(ridx1<6))")
def test_valid_becomes_const1_z3(self):
from z3 import Ints, Solver, And, If, Not, unsat
ridx0, ridx1, ridx2, alu11, alu15 = Ints('ridx0 ridx1 ridx2 alu11 alu15')
alu11 = (ridx1+ridx2)
alu15 = ((alu11+1)/7)
idx = (alu15*-31)+(((((alu11+218)/224)+ridx0)%30)*1568)
valid = (ridx2<1)&(ridx1<6)
load = If(valid, idx, 0)
# correct simplification
s = Solver()
s.add(And(0<=ridx0, ridx0<30, 0<=ridx1, ridx1<7, 0<=ridx2, ridx2<2))
simplifed_idx = (ridx0*1568)
simplifed_load = If(valid, simplifed_idx, 0)
s.add(Not(load == simplifed_load)) # Check if they are NOT equivalent
assert s.check() == unsat, f"The expressions are not equivalent. {s.model()=}"
# new solver for a wrong simplified expression
s = Solver()
s.add(And(0<=ridx0, ridx0<30, 0<=ridx1, ridx1<7, 0<=ridx2, ridx2<2))
wrong_simplifed_idx = (ridx0*1567)+ridx1
wrong_simplifed_load = If(valid, wrong_simplifed_idx, 0)
s.add(Not(load == wrong_simplifed_load)) # Check if they are NOT equivalent
assert s.check() != unsat, "The expressions are equivalent??"
print("The expressions are not equivalent.")
print(s.model())
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
load = full_graph_rewrite(load.sink()).src[0]