losing lines (#678)

* losing lines

* FLIP -> STRIDE

* shapetracker refactor
This commit is contained in:
George Hotz
2023-03-10 21:57:05 -08:00
committed by GitHub
parent d7cb8e3e56
commit 0b03216cc3
13 changed files with 103 additions and 109 deletions

View File

@@ -23,31 +23,31 @@ class CheckingShapeTracker:
def simplify(self): self.st.simplify()
def reshape(self, new_shape):
self.st.reshape(new_shape)
self.st._reshape(new_shape)
self.t = self.t.reshape(new_shape)
def permute(self, axis):
self.st.permute(axis)
self.st._permute(axis)
self.t = np.transpose(self.t, axis)
def expand(self, new_shape):
self.st.expand(new_shape)
self.st._expand(new_shape)
self.t = np.broadcast_to(self.t, new_shape)
def flip(self, axis):
self.st.flip(axis)
self.st._stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.t = np.flip(self.t, axis)
def shrink(self, arg):
self.st.shrink(arg)
self.st._shrink(arg)
self.t = self.t[tuple([slice(x[0], x[1]) for x in arg])]
def pad(self, arg):
self.st.pad(arg)
self.st._pad(arg)
self.t = np.pad(self.t, arg, constant_values=-1)
def stride(self, arg):
self.st.stride(arg)
self.st._stride(arg)
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
def __getitem__(self, val):
@@ -148,7 +148,7 @@ class TestSimplifyingShapeTracker(unittest.TestCase):
class TestComplexShapeTracker(unittest.TestCase):
def test_add_1s(self):
self.st = ShapeTracker((4, 4))
self.st = CheckingShapeTracker((4, 4))
self.st.permute((1,0))
self.st.reshape((1,4,1,4,1))
assert not self.st.contiguous
@@ -156,20 +156,20 @@ class TestComplexShapeTracker(unittest.TestCase):
assert self.st.contiguous
def test_permute_1s_simple(self):
self.st = ShapeTracker((1, 16, 9,9))
self.st = CheckingShapeTracker((1, 16, 9,9))
self.st.permute((1,0,2,3))
assert self.st.contiguous
self.st = ShapeTracker((2, 16, 9,9))
self.st = CheckingShapeTracker((2, 16, 9,9))
self.st.permute((1,0,2,3))
assert not self.st.contiguous
def test_remove_1s_simple(self):
self.st = ShapeTracker((1, 16, 1, 1))
self.st = CheckingShapeTracker((1, 16, 1, 1))
self.st.reshape((16,))
assert self.st.contiguous
def test_remove_1s(self):
self.st = ShapeTracker((1, 4, 1, 4, 1))
self.st = CheckingShapeTracker((1, 4, 1, 4, 1))
self.st.permute((0,3,2,1,4))
self.st.reshape((4,4))
assert not self.st.contiguous
@@ -177,46 +177,46 @@ class TestComplexShapeTracker(unittest.TestCase):
assert self.st.contiguous
def test_permute_reshape(self):
self.st = ShapeTracker((4, 4))
self.st = CheckingShapeTracker((4, 4))
self.st.permute((1,0))
self.st.reshape((2, 2, 2, 2))
# TODO: should also be tested by test_super_complex
assert len(self.st.views) == 1
def test_factorize_split(self):
self.st = ShapeTracker((4, 4))
self.st = CheckingShapeTracker((4, 4))
self.st.permute((1,0))
self.st.reshape((2, 2, 2, 2))
self.st.permute((2,3,0,1))
assert self.st.contiguous
def test_factorize_combine(self):
self.st = ShapeTracker((4, 4, 4))
self.st = CheckingShapeTracker((4, 4, 4))
self.st.permute((2, 0, 1))
self.st.reshape((4, 16))
self.st.permute((1, 0))
assert self.st.contiguous
def test_factorize_combine_add_ones(self):
self.st = ShapeTracker((4, 4, 4))
self.st = CheckingShapeTracker((4, 4, 4))
self.st.permute((2, 0, 1))
self.st.reshape((4, 16, 1, 1))
self.st.permute((1, 0, 2, 3))
assert self.st.contiguous
def test_fancy_factorize(self):
self.st = ShapeTracker((32, 3, 3, 1))
self.st = CheckingShapeTracker((32, 3, 3, 1))
self.st.reshape((8, 4, 3, 3))
assert len(self.st.views) == 1
def test_super_complex_2_fail(self):
self.st = ShapeTracker((4, 4, 4))
self.st = CheckingShapeTracker((4, 4, 4))
self.st.permute((2, 0, 1))
self.st.reshape((16, 4))
assert len(self.st.views) != 1
def test_work(self):
self.st = ShapeTracker((64, 1024, 4))
self.st = CheckingShapeTracker((64, 1024, 4))
self.st.reshape((1, 64, 128, 32))
self.st.permute((0, 3, 1, 2))
self.st.reshape((1, 32, 1, 64, 128))
@@ -224,7 +224,7 @@ class TestComplexShapeTracker(unittest.TestCase):
assert self.st.contiguous
def test_work2(self):
self.st = ShapeTracker((64, 1024, 4))
self.st = CheckingShapeTracker((64, 1024, 4))
self.st.reshape((1, 64, 128, 32))
self.st.permute((0, 3, 1, 2))
self.st.reshape((1, 1, 32, 64, 128))