mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Add Fuzz Test symbolic / shapetracker to CI. (#1278)
* Fuzz test symbolic and shapetracker This reverts commit d5773ddebff54c1ff608838076f0b4ff126b8aa8. * mess again * no tail * test shapetracker too * Revert mess and enable all tests * removed leftover
This commit is contained in:
41
test/external/fuzz_symbolic.py
vendored
41
test/external/fuzz_symbolic.py
vendored
@@ -1,5 +1,8 @@
|
||||
import itertools
|
||||
import random
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
random.seed(42)
|
||||
|
||||
def add_v(expr, rng=None):
|
||||
if rng is None: rng = random.randint(0,2)
|
||||
@@ -30,13 +33,14 @@ def ge(expr, rng=None):
|
||||
return expr >= rng, rng
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops = [add_v, div, mul, add_num]
|
||||
while 1:
|
||||
u1 = Variable("v1", 0, 2)
|
||||
u2 = Variable("v2", 0, 3)
|
||||
u3 = Variable("v3", 0, 4)
|
||||
ops = [add_v, div, mul, add_num, mod]
|
||||
for _ in range(1000):
|
||||
upper_bounds = [*list(range(1, 10)), 16, 32, 64, 128, 256]
|
||||
u1 = Variable("v1", 0, random.choice(upper_bounds))
|
||||
u2 = Variable("v2", 0, random.choice(upper_bounds))
|
||||
u3 = Variable("v3", 0, random.choice(upper_bounds))
|
||||
v = [u1,u2,u3]
|
||||
tape = [random.choice(ops) for _ in range(20)]
|
||||
tape = [random.choice(ops) for _ in range(random.randint(2, 30))]
|
||||
# 10% of the time, add a less than or greater than
|
||||
if random.random() < 0.05: tape.append(lt)
|
||||
elif random.random() < 0.05: tape.append(ge)
|
||||
@@ -44,18 +48,15 @@ if __name__ == "__main__":
|
||||
rngs = []
|
||||
for t in tape:
|
||||
expr, rng = t(expr)
|
||||
print(t.__name__, rng)
|
||||
if DEBUG >= 1: print(t.__name__, rng)
|
||||
rngs.append(rng)
|
||||
print(expr)
|
||||
for v1 in range(u1.min, u1.max+1):
|
||||
for v2 in range(u2.min, u2.max+1):
|
||||
for v3 in range(u3.min, u3.max+1):
|
||||
v = [v1,v2,v3]
|
||||
rn = 0
|
||||
for t,r in zip(tape, rngs):
|
||||
rn, _ = t(rn, r)
|
||||
num = eval(expr.render())
|
||||
assert num == rn, f"mismatch at {v1} {v2} {v3}, {num} != {rn}"
|
||||
#print(v1, v2, v3, num, rn)
|
||||
|
||||
|
||||
if DEBUG >=1: print(expr)
|
||||
space = list(itertools.product(range(u1.min, u1.max+1), range(u2.min, u2.max+1), range(u3.min, u3.max+1)))
|
||||
volume = len(space)
|
||||
for (v1, v2, v3) in random.sample(space, min(100, volume)):
|
||||
v = [v1,v2,v3]
|
||||
rn = 0
|
||||
for t,r in zip(tape, rngs): rn, _ = t(rn, r)
|
||||
num = eval(expr.render())
|
||||
assert num == rn, f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}"
|
||||
if DEBUG >= 1: print(f"matched {expr.render()} at {v1=} {v2=} {v3=} = {num} == {rn}")
|
||||
|
||||
Reference in New Issue
Block a user