mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
Symbolic divisor fuzzer (#13433)
* render z3 range better * working version * rename * add to workflow * factor out variable_names * smaller expressions * smaller * + back
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -326,6 +326,8 @@ jobs:
|
||||
deps: testing_unit
|
||||
- name: Fuzz Test symbolic
|
||||
run: python test/external/fuzz_symbolic.py
|
||||
- name: Fuzz Test symbolic (symbolic divisors)
|
||||
run: python test/external/fuzz_symbolic_symbolic_div.py
|
||||
- name: Fuzz Test fast idiv
|
||||
run: python test/external/fuzz_fast_idiv.py
|
||||
- name: Fuzz Test shape ops
|
||||
|
||||
66
test/external/fuzz_symbolic_symbolic_div.py
vendored
Normal file
66
test/external/fuzz_symbolic_symbolic_div.py
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
import random
|
||||
import z3
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.uop.validate import uops_to_z3
|
||||
from tinygrad.helpers import DEBUG, Context, colored
|
||||
|
||||
seed = random.randint(0, 100)
|
||||
print(f"Seed: {seed}")
|
||||
random.seed(seed)
|
||||
|
||||
def get_random_term(ranges, factors):
|
||||
# 10% chance of nesting
|
||||
if random.randint(0,9) == 0: return get_random_expr(ranges, factors)
|
||||
return random.choice(ranges)*random.choice(factors)*random.choice([1, 1, 1, -1])
|
||||
|
||||
def get_random_expr(ranges, factors):
|
||||
num_terms = random.randint(2,4)
|
||||
x = UOp.sum(*[get_random_term(ranges, factors) for _ in range(num_terms)])
|
||||
return x.alu(random.choice([Ops.IDIV, Ops.MOD]), x.ufix(random.choice(factors)*random.choice([1, 1, 1, -1])))
|
||||
|
||||
if __name__ == "__main__":
|
||||
skipped = 0
|
||||
for i in range(700):
|
||||
if i % 100 == 0:
|
||||
print(f"Running test {i}")
|
||||
upper_bounds = [*list(range(1, 4)), 16, 33, 53, 64, 256]
|
||||
variable_names = ["i", "j", "k"]
|
||||
variables = [UOp.variable(s, 1, random.choice(upper_bounds)) for s in variable_names]
|
||||
factors = variables+upper_bounds
|
||||
# add some products
|
||||
for _ in range(2): factors.append(random.choice(variables)*random.choice(variables))
|
||||
# add some adds
|
||||
for _ in range(2): factors.append(random.choice(variables)+random.choice(factors))
|
||||
num_ranges = 4
|
||||
ranges = [UOp.range(random.choice(factors), i) for i in range(num_ranges)]
|
||||
variable_names += [f"r{i}" for i in range(num_ranges)]
|
||||
expr = get_random_expr(ranges, factors)
|
||||
|
||||
with Context(CORRECT_DIVMOD_FOLDING=1):
|
||||
simplified_expr = expr.simplify()
|
||||
|
||||
if DEBUG>=1:
|
||||
print(expr.render(simplify=False), " --> ", simplified_expr.render(simplify=False))
|
||||
|
||||
solver = z3.Solver()
|
||||
solver.set(timeout=3000) # some expressions take very long verify, but its very unlikely they actually return sat
|
||||
z3_expr, z3_simplified_expr, *z3_vars = uops_to_z3(solver, expr, simplified_expr, *variables, *ranges)
|
||||
check = solver.check(z3_simplified_expr != z3_expr)
|
||||
if check == z3.unknown and DEBUG>=1:
|
||||
skipped += 1
|
||||
print("skipped z3 verification due to timeout")
|
||||
elif check == z3.sat:
|
||||
print(colored("simplify INCORRECT!", "red"))
|
||||
print(solver.model())
|
||||
var_vals = {s:solver.model()[z] for s,z in zip(variable_names, z3_vars)}
|
||||
print("reproduce with:")
|
||||
print("var_vals = ", var_vals)
|
||||
print("globals = var_vals|{'cdiv':cdiv,'cmod':cmod}")
|
||||
print("expr = ast.simplify()")
|
||||
print("assert eval(ast.render(pm=renderer_infer, simplify=False),globals) == eval(expr.render(pm=renderer_infer, simplify=False),globals)")
|
||||
print()
|
||||
|
||||
assert False
|
||||
|
||||
if DEBUG >= 2: print(f"validated {expr.render()}")
|
||||
print(f"Skipped {skipped} expressions due to timeout")
|
||||
@@ -25,7 +25,7 @@ try:
|
||||
# variables
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda x,ctx: create_bounded(x.arg, 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(f"r{x.arg}", 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: create_bounded(x.render(simplify=False), 0, ctx[1][x.src[0]]-1, ctx[0])),
|
||||
# loads are variables bounded by the min/max of the dtype
|
||||
(UPat(Ops.LOAD, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: create_bounded(f"load{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
|
||||
(UPat(Ops.LOAD, dtypes.bool, name="x"), lambda x,ctx: (z3.Bool(f"load{len(ctx[1])}", ctx=ctx[0].ctx), None)),
|
||||
|
||||
Reference in New Issue
Block a user