hotfix: add CHECK_NEQ to fuzz_shapetracker_math

This commit is contained in:
George Hotz
2024-01-30 10:07:54 -08:00
parent 09f2952dc3
commit d8f6280ffb

View File

@@ -1,4 +1,5 @@
import random
from typing import Tuple
from tqdm import trange
from tinygrad.helpers import getenv, DEBUG, colored
from tinygrad.shape.shapetracker import ShapeTracker
@@ -6,7 +7,7 @@ from test.external.fuzz_shapetracker import shapetracker_ops
from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad
from test.unit.test_shapetracker_math import st_equal, MultiShapeTracker
def fuzz_plus():
def fuzz_plus() -> Tuple[ShapeTracker, ShapeTracker]:
m = MultiShapeTracker([ShapeTracker.from_shape((random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)))])
for _ in range(4): random.choice(shapetracker_ops)(m)
backup = m.sts[0]
@@ -18,7 +19,7 @@ def fuzz_plus():
# shrink and expand aren't invertible, and stride is only invertible in the flip case
invertible_shapetracker_ops = [do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad]
def fuzz_invert():
def fuzz_invert() -> Tuple[ShapeTracker, ShapeTracker]:
start = ShapeTracker.from_shape((random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)))
m = MultiShapeTracker([start])
for _ in range(8): random.choice(invertible_shapetracker_ops)(m)
@@ -33,6 +34,10 @@ if __name__ == "__main__":
for _ in trange(total, desc=f"{fuzz}"):
st1, st2 = fuzz()
eq = st_equal(st1, st2)
if getenv("CHECK_NEQ") and eq and st1.simplify() != st2.simplify():
print(colored("same but unequal", "yellow"))
print(st1.simplify())
print(st2.simplify())
if DEBUG >= 1:
print(f"EXP: {st1}")
print(f"GOT: {st2}")