mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 14:43:57 -05:00
* render z3 range better * working version * rename * add to workflow * factor out variable_names * smaller expressions * smaller * + back
67 lines
2.7 KiB
Python
67 lines
2.7 KiB
Python
import random
|
|
import z3
|
|
from tinygrad.uop.ops import UOp, Ops
|
|
from tinygrad.uop.validate import uops_to_z3
|
|
from tinygrad.helpers import DEBUG, Context, colored
|
|
|
|
seed = random.randint(0, 100)
|
|
print(f"Seed: {seed}")
|
|
random.seed(seed)
|
|
|
|
def get_random_term(ranges, factors):
|
|
# 10% chance of nesting
|
|
if random.randint(0,9) == 0: return get_random_expr(ranges, factors)
|
|
return random.choice(ranges)*random.choice(factors)*random.choice([1, 1, 1, -1])
|
|
|
|
def get_random_expr(ranges, factors):
|
|
num_terms = random.randint(2,4)
|
|
x = UOp.sum(*[get_random_term(ranges, factors) for _ in range(num_terms)])
|
|
return x.alu(random.choice([Ops.IDIV, Ops.MOD]), x.ufix(random.choice(factors)*random.choice([1, 1, 1, -1])))
|
|
|
|
if __name__ == "__main__":
|
|
skipped = 0
|
|
for i in range(700):
|
|
if i % 100 == 0:
|
|
print(f"Running test {i}")
|
|
upper_bounds = [*list(range(1, 4)), 16, 33, 53, 64, 256]
|
|
variable_names = ["i", "j", "k"]
|
|
variables = [UOp.variable(s, 1, random.choice(upper_bounds)) for s in variable_names]
|
|
factors = variables+upper_bounds
|
|
# add some products
|
|
for _ in range(2): factors.append(random.choice(variables)*random.choice(variables))
|
|
# add some adds
|
|
for _ in range(2): factors.append(random.choice(variables)+random.choice(factors))
|
|
num_ranges = 4
|
|
ranges = [UOp.range(random.choice(factors), i) for i in range(num_ranges)]
|
|
variable_names += [f"r{i}" for i in range(num_ranges)]
|
|
expr = get_random_expr(ranges, factors)
|
|
|
|
with Context(CORRECT_DIVMOD_FOLDING=1):
|
|
simplified_expr = expr.simplify()
|
|
|
|
if DEBUG>=1:
|
|
print(expr.render(simplify=False), " --> ", simplified_expr.render(simplify=False))
|
|
|
|
solver = z3.Solver()
|
|
solver.set(timeout=3000) # some expressions take very long verify, but its very unlikely they actually return sat
|
|
z3_expr, z3_simplified_expr, *z3_vars = uops_to_z3(solver, expr, simplified_expr, *variables, *ranges)
|
|
check = solver.check(z3_simplified_expr != z3_expr)
|
|
if check == z3.unknown and DEBUG>=1:
|
|
skipped += 1
|
|
print("skipped z3 verification due to timeout")
|
|
elif check == z3.sat:
|
|
print(colored("simplify INCORRECT!", "red"))
|
|
print(solver.model())
|
|
var_vals = {s:solver.model()[z] for s,z in zip(variable_names, z3_vars)}
|
|
print("reproduce with:")
|
|
print("var_vals = ", var_vals)
|
|
print("globals = var_vals|{'cdiv':cdiv,'cmod':cmod}")
|
|
print("expr = ast.simplify()")
|
|
print("assert eval(ast.render(pm=renderer_infer, simplify=False),globals) == eval(expr.render(pm=renderer_infer, simplify=False),globals)")
|
|
print()
|
|
|
|
assert False
|
|
|
|
if DEBUG >= 2: print(f"validated {expr.render()}")
|
|
print(f"Skipped {skipped} expressions due to timeout")
|