Fix reshape merging with masks (#2877)

This commit is contained in:
Peter Cawley
2023-12-20 22:00:58 +00:00
committed by GitHub
parent 8fe24038d8
commit dae8976889
2 changed files with 19 additions and 32 deletions

View File

@@ -256,7 +256,6 @@ class TestIndexExpressions2d(unittest.TestCase):
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.expectedFailure
def test_reshape_combining_1(self):
self.st = CheckingShapeTracker((2,1,10))
self.st.pad(((2,6), (0,0), (0,0)))
@@ -264,7 +263,6 @@ class TestIndexExpressions2d(unittest.TestCase):
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.expectedFailure
def test_reshape_combining_2(self):
self.st = CheckingShapeTracker((1,1,5))
self.st.pad(((3,6), (0,0), (0,5)))
@@ -272,7 +270,6 @@ class TestIndexExpressions2d(unittest.TestCase):
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.expectedFailure
def test_reshape_combining_3(self):
self.st = CheckingShapeTracker((1,1,4))
self.st.pad(((3,6), (0,0), (1,5)))
@@ -281,7 +278,6 @@ class TestIndexExpressions2d(unittest.TestCase):
assert self.st.views[0].mask[0] == (31, 35)
self.st.assert_same()
@unittest.expectedFailure
def test_reshape_combining_4(self):
# interestingly this one is quite slow
self.st = CheckingShapeTracker((1,1,5,5,1,1,5))
@@ -290,7 +286,6 @@ class TestIndexExpressions2d(unittest.TestCase):
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.expectedFailure
def test_reshape_splitting_combining(self):
self.st = CheckingShapeTracker((1,5,5))
self.st.pad(((0,4), (0,5), (0,0)))
@@ -298,7 +293,6 @@ class TestIndexExpressions2d(unittest.TestCase):
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.expectedFailure
def test_reshape_only_1s(self):
self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1))
self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0)))
@@ -342,7 +336,6 @@ class TestIndexExpressions2d(unittest.TestCase):
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.expectedFailure
def test_splitting_big(self):
self.st = CheckingShapeTracker((1, 5, 1, 15, 1))
self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0)))
@@ -351,10 +344,9 @@ class TestIndexExpressions2d(unittest.TestCase):
self.st.reshape((2,3,5,2,5))
assert len(self.st.views) == 1
v = self.st.views[-1]
assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
assert v.strides == (0, 5, 1, 0, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
self.st.assert_same()
@unittest.expectedFailure
def test_combining_big(self):
self.st = CheckingShapeTracker((1,3,1,5,3,1))
self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0)))