mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
[bounty] Stride is flip (#8792)
* replace stride with flip * Complete replacing stride with flip clean flip function in view.py fix tests * fix tests for multi shapetracker * fix tests for fuzz shapetracker * fix tests for fuzz shapetracker * debug * debug * fix * fix * fix --------- Co-authored-by: George Hotz <geohot@gmail.com> Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
15
test/external/fuzz_shapetracker.py
vendored
15
test/external/fuzz_shapetracker.py
vendored
@@ -38,17 +38,10 @@ def do_shrink(st):
|
||||
if DEBUG >= 1: print("st.shrink(", shrink, ")")
|
||||
st.shrink(shrink)
|
||||
|
||||
def do_stride(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape)))
|
||||
if DEBUG >= 1: print("st.stride(", stride, ")")
|
||||
st.stride(stride)
|
||||
|
||||
def do_flip(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
stride = tuple(-1 if i==c else 1 for i in range(len(st.shape)))
|
||||
if DEBUG >= 1: print("st.stride(", stride, ")")
|
||||
st.stride(stride)
|
||||
flip = tuple(random.random() < 0.5 for _ in st.shape)
|
||||
if DEBUG >= 1: print("st.flip(", flip, ")")
|
||||
st.flip(flip)
|
||||
|
||||
def do_expand(st):
|
||||
c = [i for i,s in enumerate(st.shape) if s==1]
|
||||
@@ -58,7 +51,7 @@ def do_expand(st):
|
||||
if DEBUG >= 1: print("st.expand(", expand, ")")
|
||||
st.expand(expand)
|
||||
|
||||
shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand]
|
||||
shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_flip, do_expand]
|
||||
|
||||
if __name__ == "__main__":
|
||||
random.seed(42)
|
||||
|
||||
@@ -43,9 +43,9 @@ class CheckingShapeTracker:
|
||||
self.t = np.broadcast_to(self.t, new_shape)
|
||||
return self
|
||||
|
||||
def flip(self, axis):
|
||||
self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
|
||||
self.t = np.flip(self.t, axis)
|
||||
def flip(self, arg):
|
||||
self.st = self.st.flip(arg)
|
||||
self.t = np.flip(self.t, tuple(i for i in range(len(arg)) if arg[i]))
|
||||
return self
|
||||
|
||||
def shrink(self, arg):
|
||||
@@ -58,11 +58,6 @@ class CheckingShapeTracker:
|
||||
self.t = np.pad(self.t, arg, constant_values=-1)
|
||||
return self
|
||||
|
||||
def stride(self, arg):
|
||||
self.st = self.st.stride(arg)
|
||||
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
|
||||
return self
|
||||
|
||||
def __getitem__(self, val):
|
||||
return self.t.flatten()[val]
|
||||
|
||||
@@ -579,20 +574,18 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
|
||||
self.st.shrink(((1, 2), (1, 3), (1, 3)))
|
||||
self.st.reshape((1, 4))
|
||||
self.st.shrink(((0, 1), (1, 3)))
|
||||
print(self.st.st)
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.st)
|
||||
def test_case_2(self):
|
||||
self.st.stride( (1, 1, -2) )
|
||||
self.st.reshape( (3, 6) )
|
||||
self.st.flip( (True, False, True) )
|
||||
self.st.reshape( (3, 9) )
|
||||
self.st.shrink( ((1, 2), (1, 5)) )
|
||||
self.st.stride( (1, -1) )
|
||||
self.st.flip( (True, True) )
|
||||
def test_case_3(self):
|
||||
self.st.shrink( ((0, 2), (0, 2), (0, 1)) )
|
||||
self.st.permute( (1, 0, 2) )
|
||||
self.st.reshape( (4,) )
|
||||
self.st.shrink( ((0, 3),) )
|
||||
self.st.stride( (-1,) )
|
||||
self.st.flip( (True, False) )
|
||||
def test_case_4(self):
|
||||
self.st.reshape( (3, 3, 3, 1) )
|
||||
self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) )
|
||||
@@ -687,21 +680,13 @@ class TestShapeTracker(unittest.TestCase):
|
||||
self.st.reshape((9,6,1))
|
||||
self.st.expand((9,6,4))
|
||||
|
||||
def test_pad_stride(self):
|
||||
def test_pad_flip(self):
|
||||
self.st.pad(((1,4), (1,3)))
|
||||
self.st.stride((2,2))
|
||||
self.st.flip((True, False))
|
||||
|
||||
def test_pad_stride_neg(self):
|
||||
self.st.pad(((1,2), (1,0)))
|
||||
self.st.stride((-1,-1))
|
||||
|
||||
def test_pad_stride_both(self):
|
||||
self.st.pad(((1,2), (1,0)))
|
||||
self.st.stride((-2,-2))
|
||||
|
||||
def test_shrink_pad(self):
|
||||
self.st.shrink(((0,4), (0,4)))
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
def test_pad_flip_int(self):
|
||||
self.st.pad(((1,4), (1,3)))
|
||||
self.st.flip((0, 1))
|
||||
|
||||
def test_reshape(self):
|
||||
new_shape = self.st.shape[::-1]
|
||||
@@ -722,13 +707,13 @@ class TestShapeTracker(unittest.TestCase):
|
||||
self.apply(lambda x: x.expand(tuple(new_shape)))
|
||||
|
||||
def test_flip_0(self):
|
||||
self.apply(lambda x: x.flip((0,)))
|
||||
self.apply(lambda x: x.flip((True, False)))
|
||||
|
||||
def test_flip_1(self):
|
||||
self.apply(lambda x: x.flip((1,)))
|
||||
self.apply(lambda x: x.flip((False, True)))
|
||||
|
||||
def test_flip_01(self):
|
||||
self.apply(lambda x: x.flip((0,1)))
|
||||
self.apply(lambda x: x.flip((True, True)))
|
||||
|
||||
def test_slice_0(self):
|
||||
self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1]))))
|
||||
@@ -754,16 +739,13 @@ class TestShapeTracker(unittest.TestCase):
|
||||
self.apply(lambda x: x.shrink(((0, 2), (3, 4))))
|
||||
self.apply(lambda x: x.expand((2, 10)))
|
||||
|
||||
def test_double_stride(self):
|
||||
self.apply(lambda x: x.stride((1, 2)))
|
||||
self.apply(lambda x: x.stride((2, 1)))
|
||||
def test_double_flip(self):
|
||||
self.apply(lambda x: x.flip((True, False)))
|
||||
self.apply(lambda x: x.flip((True, False)))
|
||||
|
||||
def test_stride(self): self.apply(lambda x: x.stride((2,1)))
|
||||
def test_stride_int(self): self.apply(lambda x: x.stride((1,2)))
|
||||
def test_stride_2(self): self.apply(lambda x: x.stride((2,2)))
|
||||
def test_stride_n(self): self.apply(lambda x: x.stride((-2,1)))
|
||||
def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2)))
|
||||
def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2)))
|
||||
def test_flip(self): self.apply(lambda x: x.flip((True, False)))
|
||||
def test_flip2(self): self.apply(lambda x: x.flip((False, True)))
|
||||
def test_flip3(self): self.apply(lambda x: x.flip((True, True)))
|
||||
|
||||
def test_reshape_then_permute(self):
|
||||
self.test_reshape()
|
||||
@@ -838,24 +820,12 @@ class TestShapeTrackerSize(unittest.TestCase):
|
||||
self.assertEqual(st.real_size(), 10)
|
||||
|
||||
def test_pad_size_multiview(self):
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,))
|
||||
self.assertEqual(st.real_size(), 100)
|
||||
|
||||
# TODO improve real_size accuracy in cases like this?
|
||||
@unittest.expectedFailure
|
||||
def test_stride_size(self):
|
||||
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
|
||||
self.assertEqual(st1.real_size(), 78)
|
||||
self.assertEqual(st2.real_size(), 65)
|
||||
|
||||
def test_stride_size_bounds(self):
|
||||
# lower bound checks that real_size doesn't give false positive for fitting in a buffer
|
||||
# upper bound checks that real_size doesn't exceed N when movementops were applied to from_shape((N,))
|
||||
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
|
||||
self.assertTrue(78 <= st1.real_size() <= 100)
|
||||
self.assertTrue(65 <= st2.real_size() <= 100)
|
||||
def test_flip_size(self):
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).flip((True, True))
|
||||
self.assertEqual(st.real_size(), 100)
|
||||
|
||||
class TestConsecutive(unittest.TestCase):
|
||||
@classmethod
|
||||
|
||||
@@ -13,7 +13,7 @@ class MultiShapeTracker:
|
||||
def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts]
|
||||
def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts]
|
||||
def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts]
|
||||
def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts]
|
||||
def flip(self, arg): self.sts = [x.flip(arg) for x in self.sts]
|
||||
def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts]
|
||||
|
||||
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
|
||||
@@ -75,8 +75,8 @@ class TestShapeTrackerAdd(unittest.TestCase):
|
||||
backup = st.sts[0]
|
||||
st.sts.append(ShapeTracker.from_shape(backup.shape))
|
||||
st.reshape( (45,) )
|
||||
st.stride( (4,) )
|
||||
st.reshape( (4, 3) )
|
||||
st.flip( (True,) )
|
||||
st.reshape( (15, 3) )
|
||||
assert st_equal(backup + st.sts[1], st.sts[0])
|
||||
|
||||
def test_off_by_one(self):
|
||||
@@ -155,22 +155,17 @@ class TestShapeTrackerInvert(unittest.TestCase):
|
||||
|
||||
def test_can_invert_flip(self):
|
||||
a = ShapeTracker.from_shape((20, 10))
|
||||
x = a.stride((-1,1))
|
||||
x = a.flip((True,False))
|
||||
ap = x + x.invert(a.shape)
|
||||
assert st_equal(ap, a)
|
||||
|
||||
def test_can_invert_flip_permute(self):
|
||||
a = ShapeTracker.from_shape((20, 10))
|
||||
x = a.permute((1,0))
|
||||
x = x.stride((-1,1))
|
||||
x = x.flip((True,False))
|
||||
ap = x + x.invert(a.shape)
|
||||
assert st_equal(ap, a)
|
||||
|
||||
def test_cant_invert_stride(self):
|
||||
a = ShapeTracker.from_shape((10, 10))
|
||||
x = a.stride((2,2))
|
||||
assert x.invert(a.shape) is None
|
||||
|
||||
def test_invert_failure(self):
|
||||
a = ShapeTracker.from_shape((2, 5))
|
||||
x = a.pad( ((2, 0), (0, 0)) )
|
||||
|
||||
Reference in New Issue
Block a user