Revert "Better reshape (#1423)" (#1538)

This commit is contained in:
wozeparrot
2023-08-14 13:04:54 -04:00
committed by GitHub
parent cf2bf1518d
commit 9cb2bda34f
2 changed files with 22 additions and 137 deletions

View File

@@ -122,6 +122,7 @@ class TestRealSimplifies(unittest.TestCase):
View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
class TestIndexExpressions2d(unittest.TestCase):
def setUp(self):
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
offsets = [0, 1, 15, 28, 10000]
@@ -187,7 +188,6 @@ class TestIndexExpressions2d(unittest.TestCase):
st.expand((base_shape[0], base_shape[1], base_shape[1]))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
def test_permute_reshape_1(self): # This tests multiple views
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
@@ -417,94 +417,6 @@ class TestMaskedShapeTracker(unittest.TestCase):
self.st.pad(((1,1), (1,1)))
self.st.assert_same()
def test_reshaping_splitting(self):
self.st = CheckingShapeTracker((5,10,5,10))
self.st.permute((1, 0, 3, 2))
self.st.pad(((0,0), (0,5), (0,0), (0,5)))
self.st.reshape((10,2,5,10,2,5))
assert len(self.st.views) == 1
self.st.assert_same()
def test_reshape_combining_1(self):
self.st = CheckingShapeTracker((2,1,10))
self.st.pad(((2,6), (0,0), (0,0)))
self.st.reshape((100,))
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.skip("Can't make this optimization yet")
def test_reshape_combining_2(self):
self.st = CheckingShapeTracker((1,1,5))
self.st.pad(((3,6), (0,0), (0,5)))
self.st.reshape((100,))
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.skip("Can't make this optimization yet")
def test_reshape_splitting_combining(self):
self.st = CheckingShapeTracker((1,5,5))
self.st.pad(((0,4), (0,5), (0,0)))
self.st.reshape((10,25))
assert len(self.st.views) == 1
self.st.assert_same()
def test_reshape_only_1s(self):
self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1))
self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0)))
self.st.reshape((5, 6, 3, 5))
assert len(self.st.views) == 1
self.st.assert_same()
self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1))
assert len(self.st.views) == 1
self.st.assert_same()
self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1))
assert len(self.st.views) == 1
self.st.assert_same()
def test_zero_mask_1(self):
self.st = CheckingShapeTracker((1, 3, 2))
self.st.pad(((0,0), (0,3), (0,0)))
self.st.shrink(((0,1), (3,6), (0,2)))
self.st.reshape((3,2))
self.st.assert_same()
self.st.reshape((1, 3, 1, 2, 1))
self.st.assert_same()
def test_zero_mask_2(self):
self.st = CheckingShapeTracker((1, 3, 2))
self.st.pad(((0,2), (0,3), (0,0)))
self.st.shrink(((2,3), (3,6), (0,2)))
self.st.reshape((3,2))
self.st.assert_same()
self.st.reshape((1, 3, 1, 2, 1))
self.st.assert_same()
def test_expanded_reshaped(self):
self.st = CheckingShapeTracker((1, 3, 2, 1))
self.st.expand((5, 3, 2, 2))
self.st.pad(((0,0), (0,3), (0,0), (0, 0)))
self.st.reshape((5, 2, 3, 2, 2))
assert len(self.st.views) == 1
self.st.assert_same()
def test_splitting_big(self):
self.st = CheckingShapeTracker((1, 5, 1, 15, 1))
self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0)))
self.st.reshape((10, 1, 30))
self.st.permute((2,1,0))
self.st.reshape((2,3,5,2,5))
assert len(self.st.views) == 1
v = self.st.views[-1]
assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
def test_combining_big(self):
self.st = CheckingShapeTracker((1,3,1,5,3,1))
self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0)))
self.st.reshape((1,1,1,105,1,1))
assert len(self.st.views) == 1
v = self.st.views[-1]
assert v.strides == (0, 0, 0, 1, 0, 0) and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)), v.offset == -30
class TestShapeTracker(unittest.TestCase):
def setUp(self):
self.st = CheckingShapeTracker((7,4))