mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix shapetracker math (#2861)
* proper test * all st math good now * fix real_strides bug
This commit is contained in:
27
test/external/fuzz_shapetracker_math.py
vendored
27
test/external/fuzz_shapetracker_math.py
vendored
@@ -1,21 +1,10 @@
|
||||
import random
|
||||
from typing import List
|
||||
from tqdm import trange
|
||||
from tinygrad.helpers import getenv, DEBUG, colored
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from test.external.fuzz_shapetracker import shapetracker_ops
|
||||
from test.external.fuzz_shapetracker import do_permute, do_reshape_split_one, do_reshape_combine_two, do_flip, do_pad
|
||||
|
||||
class MultiShapeTracker:
|
||||
def __init__(self, sts:List[ShapeTracker]): self.sts = sts
|
||||
@property
|
||||
def shape(self): return self.sts[0].shape
|
||||
def reshape(self, arg): self.sts = [x.reshape(arg) for x in self.sts]
|
||||
def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts]
|
||||
def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts]
|
||||
def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts]
|
||||
def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts]
|
||||
def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts]
|
||||
from test.unit.test_shapetracker_math import st_equal, MultiShapeTracker
|
||||
|
||||
def fuzz_plus():
|
||||
m = MultiShapeTracker([ShapeTracker.from_shape((random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)))])
|
||||
@@ -34,20 +23,18 @@ def fuzz_invert():
|
||||
m = MultiShapeTracker([start])
|
||||
for _ in range(8): random.choice(invertible_shapetracker_ops)(m)
|
||||
inv = m.sts[0].invert(start.shape)
|
||||
st_sum = (ShapeTracker.from_shape(m.sts[0].shape) + inv) if inv else None
|
||||
st_sum = (m.sts[0] + inv) if inv else None
|
||||
return start, st_sum
|
||||
|
||||
if __name__ == "__main__":
|
||||
# random.seed(42)
|
||||
total = getenv("CNT", 1000)
|
||||
for fuzz in [globals()[f'fuzz_{x}'] for x in getenv("FUZZ", "invert,plus").split(",")]:
|
||||
good = 0
|
||||
for _ in trange(total):
|
||||
for _ in trange(total, desc=f"{fuzz}"):
|
||||
st1, st2 = fuzz()
|
||||
if st1 == st2: good += 1
|
||||
if st1 != st2 or DEBUG >= 1:
|
||||
eq = st_equal(st1, st2)
|
||||
if DEBUG >= 1:
|
||||
print(f"EXP: {st1}")
|
||||
print(f"GOT: {st2}")
|
||||
print(colored("****", "red" if st1 != st2 else "green"))
|
||||
print(f"hit {good}/{total}")
|
||||
assert good == total
|
||||
print(colored("****", "green" if eq else "red"))
|
||||
if not eq: exit(0)
|
||||
|
||||
@@ -1,5 +1,38 @@
|
||||
import unittest
|
||||
from typing import List
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
|
||||
class MultiShapeTracker:
|
||||
def __init__(self, sts:List[ShapeTracker]): self.sts = sts
|
||||
@property
|
||||
def shape(self): return self.sts[0].shape
|
||||
def reshape(self, arg): self.sts = [x.reshape(arg) for x in self.sts]
|
||||
def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts]
|
||||
def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts]
|
||||
def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts]
|
||||
def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts]
|
||||
def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts]
|
||||
|
||||
def st_equal(st1, st2) -> bool:
|
||||
if st1.shape != st2.shape: return False
|
||||
if st1 == st2: return True
|
||||
idx = Variable("idx", 0, prod(st1.shape)-1)
|
||||
st1_idx, st1_valid = st1.expr_node(idx)
|
||||
st2_idx, st2_valid = st2.expr_node(idx)
|
||||
for i in range(idx.min, idx.max):
|
||||
st1_off = sym_infer(st1_idx, {idx: i})
|
||||
st2_off = sym_infer(st2_idx, {idx: i})
|
||||
st1_v = sym_infer(st1_valid, {idx: i})
|
||||
st2_v = sym_infer(st2_valid, {idx: i})
|
||||
if st1_v != st2_v or (st1_off != st2_off and st1_v):
|
||||
print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}")
|
||||
print(st1)
|
||||
print(st2)
|
||||
return False
|
||||
return True
|
||||
|
||||
class TestShapeTrackerBasics(unittest.TestCase):
|
||||
def test_pad_shrink_removes_mask(self):
|
||||
@@ -22,6 +55,11 @@ class TestShapeTrackerBasics(unittest.TestCase):
|
||||
x1 = x1.reshape( (2, 2, 5) )
|
||||
assert x == x1.simplify()
|
||||
|
||||
def test_simplify_is_correct(self):
|
||||
multiv = ShapeTracker(views=(View(shape=(15, 3), strides=(9, 1), offset=6, mask=None, contiguous=False),
|
||||
View(shape=(4, 3), strides=(12, 4), offset=0, mask=None, contiguous=False)))
|
||||
assert st_equal(multiv, multiv.simplify())
|
||||
|
||||
class TestShapeTrackerAdd(unittest.TestCase):
|
||||
def test_simple_add_reshape(self):
|
||||
a = ShapeTracker.from_shape((10, 10))
|
||||
@@ -36,6 +74,16 @@ class TestShapeTrackerAdd(unittest.TestCase):
|
||||
b = b.permute((1,0))
|
||||
assert a+b == ShapeTracker.from_shape((10, 10))
|
||||
|
||||
def test_plus_real1(self):
|
||||
st = MultiShapeTracker([ShapeTracker.from_shape((15, 9))])
|
||||
st.shrink( ((0, 15), (6, 9)) )
|
||||
backup = st.sts[0]
|
||||
st.sts.append(ShapeTracker.from_shape(backup.shape))
|
||||
st.reshape( (45,) )
|
||||
st.stride( (4,) )
|
||||
st.reshape( (4, 3) )
|
||||
assert st_equal(backup + st.sts[1], st.sts[0])
|
||||
|
||||
class TestShapeTrackerInvert(unittest.TestCase):
|
||||
def test_invert_reshape(self):
|
||||
a = ShapeTracker.from_shape((10, 10))
|
||||
@@ -46,13 +94,20 @@ class TestShapeTrackerInvert(unittest.TestCase):
|
||||
def test_invert_permute(self):
|
||||
a = ShapeTracker.from_shape((5, 20))
|
||||
x = a.permute((1,0))
|
||||
ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape)
|
||||
ap = x + x.invert(a.shape)
|
||||
assert ap == a, f"{ap} != {a}"
|
||||
|
||||
def test_invert_permute_3(self):
|
||||
a = ShapeTracker.from_shape((8, 4, 5))
|
||||
x = a.permute((1,2,0))
|
||||
ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape)
|
||||
ap = x + x.invert(a.shape)
|
||||
assert ap == a, f"{ap} != {a}"
|
||||
|
||||
def test_invert_real1(self):
|
||||
a = ShapeTracker.from_shape((3, 6, 10))
|
||||
x = a.reshape( (3, 3, 2, 10) )
|
||||
x = x.permute( (2, 1, 3, 0) )
|
||||
ap = x + x.invert(a.shape)
|
||||
assert ap == a, f"{ap} != {a}"
|
||||
|
||||
def test_cant_invert_expand(self):
|
||||
@@ -66,10 +121,17 @@ class TestShapeTrackerInvert(unittest.TestCase):
|
||||
assert x.invert(a.shape) is None
|
||||
|
||||
def test_can_invert_flip(self):
|
||||
a = ShapeTracker.from_shape((10, 10))
|
||||
a = ShapeTracker.from_shape((20, 10))
|
||||
x = a.stride((-1,1))
|
||||
ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape)
|
||||
assert ap == a, f"{ap} != {a}"
|
||||
ap = x + x.invert(a.shape)
|
||||
assert st_equal(ap, a)
|
||||
|
||||
def test_can_invert_flip_permute(self):
|
||||
a = ShapeTracker.from_shape((20, 10))
|
||||
x = a.permute((1,0))
|
||||
x = x.stride((-1,1))
|
||||
ap = x + x.invert(a.shape)
|
||||
assert st_equal(ap, a)
|
||||
|
||||
def test_cant_invert_stride(self):
|
||||
a = ShapeTracker.from_shape((10, 10))
|
||||
@@ -81,8 +143,8 @@ class TestShapeTrackerInvert(unittest.TestCase):
|
||||
x = a.pad( ((2, 0), (0, 0)) )
|
||||
x = x.reshape( (2, 2, 5) )
|
||||
x = x.reshape( (4, 5) )
|
||||
ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape)
|
||||
assert ap == a, f"{ap} != {a}"
|
||||
ap = x + x.invert(a.shape)
|
||||
assert st_equal(ap, a)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user