[bounty] Stride is flip (#8792)

* replace stride with flip

* Complete replacing stride with flip

clean flip function in view.py
fix tests

* fix tests for multi shapetracker

* fix tests for fuzz shapetracker

* fix tests for fuzz shapetracker

* debug

* debug

* fix

* fix

* fix

---------

Co-authored-by: George Hotz <geohot@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Ankit Avinash
2025-01-31 08:04:10 +05:30
committed by GitHub
parent 0513b0c17d
commit 7647cd8428
9 changed files with 50 additions and 97 deletions

View File

@@ -38,17 +38,10 @@ def do_shrink(st):
if DEBUG >= 1: print("st.shrink(", shrink, ")")
st.shrink(shrink)
def do_stride(st):
c = random.randint(0, len(st.shape)-1)
stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape)))
if DEBUG >= 1: print("st.stride(", stride, ")")
st.stride(stride)
def do_flip(st):
c = random.randint(0, len(st.shape)-1)
stride = tuple(-1 if i==c else 1 for i in range(len(st.shape)))
if DEBUG >= 1: print("st.stride(", stride, ")")
st.stride(stride)
flip = tuple(random.random() < 0.5 for _ in st.shape)
if DEBUG >= 1: print("st.flip(", flip, ")")
st.flip(flip)
def do_expand(st):
c = [i for i,s in enumerate(st.shape) if s==1]
@@ -58,7 +51,7 @@ def do_expand(st):
if DEBUG >= 1: print("st.expand(", expand, ")")
st.expand(expand)
shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand]
shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_flip, do_expand]
if __name__ == "__main__":
random.seed(42)

View File

@@ -43,9 +43,9 @@ class CheckingShapeTracker:
self.t = np.broadcast_to(self.t, new_shape)
return self
def flip(self, axis):
self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.t = np.flip(self.t, axis)
def flip(self, arg):
self.st = self.st.flip(arg)
self.t = np.flip(self.t, tuple(i for i in range(len(arg)) if arg[i]))
return self
def shrink(self, arg):
@@ -58,11 +58,6 @@ class CheckingShapeTracker:
self.t = np.pad(self.t, arg, constant_values=-1)
return self
def stride(self, arg):
self.st = self.st.stride(arg)
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
return self
def __getitem__(self, val):
return self.t.flatten()[val]
@@ -579,20 +574,18 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
self.st.shrink(((1, 2), (1, 3), (1, 3)))
self.st.reshape((1, 4))
self.st.shrink(((0, 1), (1, 3)))
print(self.st.st)
self.st = self.st.simplify()
print(self.st.st)
def test_case_2(self):
self.st.stride( (1, 1, -2) )
self.st.reshape( (3, 6) )
self.st.flip( (True, False, True) )
self.st.reshape( (3, 9) )
self.st.shrink( ((1, 2), (1, 5)) )
self.st.stride( (1, -1) )
self.st.flip( (True, True) )
def test_case_3(self):
self.st.shrink( ((0, 2), (0, 2), (0, 1)) )
self.st.permute( (1, 0, 2) )
self.st.reshape( (4,) )
self.st.shrink( ((0, 3),) )
self.st.stride( (-1,) )
self.st.flip( (True, False) )
def test_case_4(self):
self.st.reshape( (3, 3, 3, 1) )
self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) )
@@ -687,21 +680,13 @@ class TestShapeTracker(unittest.TestCase):
self.st.reshape((9,6,1))
self.st.expand((9,6,4))
def test_pad_stride(self):
def test_pad_flip(self):
self.st.pad(((1,4), (1,3)))
self.st.stride((2,2))
self.st.flip((True, False))
def test_pad_stride_neg(self):
self.st.pad(((1,2), (1,0)))
self.st.stride((-1,-1))
def test_pad_stride_both(self):
self.st.pad(((1,2), (1,0)))
self.st.stride((-2,-2))
def test_shrink_pad(self):
self.st.shrink(((0,4), (0,4)))
self.st.pad(((1,1), (1,1)))
def test_pad_flip_int(self):
self.st.pad(((1,4), (1,3)))
self.st.flip((0, 1))
def test_reshape(self):
new_shape = self.st.shape[::-1]
@@ -722,13 +707,13 @@ class TestShapeTracker(unittest.TestCase):
self.apply(lambda x: x.expand(tuple(new_shape)))
def test_flip_0(self):
self.apply(lambda x: x.flip((0,)))
self.apply(lambda x: x.flip((True, False)))
def test_flip_1(self):
self.apply(lambda x: x.flip((1,)))
self.apply(lambda x: x.flip((False, True)))
def test_flip_01(self):
self.apply(lambda x: x.flip((0,1)))
self.apply(lambda x: x.flip((True, True)))
def test_slice_0(self):
self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1]))))
@@ -754,16 +739,13 @@ class TestShapeTracker(unittest.TestCase):
self.apply(lambda x: x.shrink(((0, 2), (3, 4))))
self.apply(lambda x: x.expand((2, 10)))
def test_double_stride(self):
self.apply(lambda x: x.stride((1, 2)))
self.apply(lambda x: x.stride((2, 1)))
def test_double_flip(self):
self.apply(lambda x: x.flip((True, False)))
self.apply(lambda x: x.flip((True, False)))
def test_stride(self): self.apply(lambda x: x.stride((2,1)))
def test_stride_int(self): self.apply(lambda x: x.stride((1,2)))
def test_stride_2(self): self.apply(lambda x: x.stride((2,2)))
def test_stride_n(self): self.apply(lambda x: x.stride((-2,1)))
def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2)))
def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2)))
def test_flip(self): self.apply(lambda x: x.flip((True, False)))
def test_flip2(self): self.apply(lambda x: x.flip((False, True)))
def test_flip3(self): self.apply(lambda x: x.flip((True, True)))
def test_reshape_then_permute(self):
self.test_reshape()
@@ -838,24 +820,12 @@ class TestShapeTrackerSize(unittest.TestCase):
self.assertEqual(st.real_size(), 10)
def test_pad_size_multiview(self):
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,))
self.assertEqual(st.real_size(), 100)
# TODO improve real_size accuracy in cases like this?
@unittest.expectedFailure
def test_stride_size(self):
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
self.assertEqual(st1.real_size(), 78)
self.assertEqual(st2.real_size(), 65)
def test_stride_size_bounds(self):
# lower bound checks that real_size doesn't give false positive for fitting in a buffer
# upper bound checks that real_size doesn't exceed N when movementops were applied to from_shape((N,))
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
self.assertTrue(78 <= st1.real_size() <= 100)
self.assertTrue(65 <= st2.real_size() <= 100)
def test_flip_size(self):
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).flip((True, True))
self.assertEqual(st.real_size(), 100)
class TestConsecutive(unittest.TestCase):
@classmethod

View File

@@ -13,7 +13,7 @@ class MultiShapeTracker:
def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts]
def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts]
def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts]
def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts]
def flip(self, arg): self.sts = [x.flip(arg) for x in self.sts]
def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts]
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
@@ -75,8 +75,8 @@ class TestShapeTrackerAdd(unittest.TestCase):
backup = st.sts[0]
st.sts.append(ShapeTracker.from_shape(backup.shape))
st.reshape( (45,) )
st.stride( (4,) )
st.reshape( (4, 3) )
st.flip( (True,) )
st.reshape( (15, 3) )
assert st_equal(backup + st.sts[1], st.sts[0])
def test_off_by_one(self):
@@ -155,22 +155,17 @@ class TestShapeTrackerInvert(unittest.TestCase):
def test_can_invert_flip(self):
a = ShapeTracker.from_shape((20, 10))
x = a.stride((-1,1))
x = a.flip((True,False))
ap = x + x.invert(a.shape)
assert st_equal(ap, a)
def test_can_invert_flip_permute(self):
a = ShapeTracker.from_shape((20, 10))
x = a.permute((1,0))
x = x.stride((-1,1))
x = x.flip((True,False))
ap = x + x.invert(a.shape)
assert st_equal(ap, a)
def test_cant_invert_stride(self):
a = ShapeTracker.from_shape((10, 10))
x = a.stride((2,2))
assert x.invert(a.shape) is None
def test_invert_failure(self):
a = ShapeTracker.from_shape((2, 5))
x = a.pad( ((2, 0), (0, 0)) )