Files
tinygrad/test/external/fuzz_shapetracker_math.py
George Hotz ca59054463 fix shapetracker math (#2861)
* proper test

* all st math good now

* fix real_strides bug
2023-12-19 22:17:34 -08:00

41 lines
1.7 KiB
Python

import random
from tqdm import trange
from tinygrad.helpers import getenv, DEBUG, colored
from tinygrad.shape.shapetracker import ShapeTracker
from test.external.fuzz_shapetracker import shapetracker_ops
from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad
from test.unit.test_shapetracker_math import st_equal, MultiShapeTracker
def fuzz_plus():
  """Fuzz ShapeTracker `+` (composition) against applying the same ops to one tracker.

  A random 3D tracker gets 4 random ops. Its state at that point is saved as `prefix`.
  A fresh tracker of the resulting shape is then added to the MultiShapeTracker, and 4 more random
  ops are applied. The MultiShapeTracker presumably applies each op to every tracker it holds.

  Returns (expected, got): the tracker that saw all 8 ops, and `prefix` composed via `+` with the
  tracker that only saw the last 4 ops.
  """
  shape = tuple(random.randint(1, 10) for _ in range(3))
  multi = MultiShapeTracker([ShapeTracker.from_shape(shape)])

  def mutate(count):
    # apply `count` randomly chosen movement ops to every tracker in `multi`
    for _ in range(count):
      random.choice(shapetracker_ops)(multi)

  mutate(4)
  prefix = multi.sts[0]
  multi.sts.append(ShapeTracker.from_shape(prefix.shape))
  mutate(4)
  composed = prefix + multi.sts[1]
  return multi.sts[0], composed
# Subset of the movement ops used by fuzz_invert. Shrink and expand are left out because they
# are not invertible. Stride is left out because it is only invertible in the flip case, which
# do_flip already covers.
invertible_shapetracker_ops = [
  do_permute,
  do_reshape_split_one,
  do_reshape_combine_two,
  do_flip,
  do_pad,
]
def fuzz_invert():
  """Fuzz ShapeTracker.invert.

  Builds a random 3D ShapeTracker `start`, applies 8 ops drawn from invertible_shapetracker_ops,
  then asks the mutated tracker for its inverse with respect to `start.shape`.

  Returns (start, st_sum). `st_sum` is the mutated tracker composed with its inverse via `+`;
  the caller checks it for equivalence with `start`. `st_sum` is None when invert() returned None.
  """
  start = ShapeTracker.from_shape((random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)))
  m = MultiShapeTracker([start])
  for _ in range(8): random.choice(invertible_shapetracker_ops)(m)
  inv = m.sts[0].invert(start.shape)
  # invert() presumably returns None when no inverse exists.
  # Test for None explicitly instead of relying on the truthiness of the returned tracker.
  st_sum = (m.sts[0] + inv) if inv is not None else None
  return start, st_sum
if __name__ == "__main__":
  # Set SEED=<n> (n >= 0) to make a run reproducible. The default of -1 leaves the RNG unseeded.
  seed = getenv("SEED", -1)
  if seed >= 0: random.seed(seed)
  total = getenv("CNT", 1000)
  # FUZZ selects which fuzz_<name> functions to run, comma separated
  for fuzz in [globals()[f'fuzz_{x}'] for x in getenv("FUZZ", "invert,plus").split(",")]:
    for _ in trange(total, desc=f"{fuzz}"):
      st1, st2 = fuzz()
      # fuzz_invert returns None when invert() fails on an "invertible" op chain; count that as a mismatch
      eq = st2 is not None and st_equal(st1, st2)
      # always show the failing pair, even when DEBUG is off, so a mismatch is never silent
      if DEBUG >= 1 or not eq:
        print(f"EXP: {st1}")
        print(f"GOT: {st2}")
        print(colored("****", "green" if eq else "red"))
      # a mismatch is a failure: stop with a non-zero exit status so scripts/CI notice
      if not eq: raise SystemExit(1)