mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Fix reshape merging with masks (#2877)
This commit is contained in:
@@ -256,7 +256,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_reshape_combining_1(self):
|
||||
self.st = CheckingShapeTracker((2,1,10))
|
||||
self.st.pad(((2,6), (0,0), (0,0)))
|
||||
@@ -264,7 +263,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_reshape_combining_2(self):
|
||||
self.st = CheckingShapeTracker((1,1,5))
|
||||
self.st.pad(((3,6), (0,0), (0,5)))
|
||||
@@ -272,7 +270,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_reshape_combining_3(self):
|
||||
self.st = CheckingShapeTracker((1,1,4))
|
||||
self.st.pad(((3,6), (0,0), (1,5)))
|
||||
@@ -281,7 +278,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
assert self.st.views[0].mask[0] == (31, 35)
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_reshape_combining_4(self):
|
||||
# interestingly this one is quite slow
|
||||
self.st = CheckingShapeTracker((1,1,5,5,1,1,5))
|
||||
@@ -290,7 +286,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_reshape_splitting_combining(self):
|
||||
self.st = CheckingShapeTracker((1,5,5))
|
||||
self.st.pad(((0,4), (0,5), (0,0)))
|
||||
@@ -298,7 +293,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_reshape_only_1s(self):
|
||||
self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1))
|
||||
self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0)))
|
||||
@@ -342,7 +336,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_splitting_big(self):
|
||||
self.st = CheckingShapeTracker((1, 5, 1, 15, 1))
|
||||
self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0)))
|
||||
@@ -351,10 +344,9 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
self.st.reshape((2,3,5,2,5))
|
||||
assert len(self.st.views) == 1
|
||||
v = self.st.views[-1]
|
||||
assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
|
||||
assert v.strides == (0, 5, 1, 0, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_combining_big(self):
|
||||
self.st = CheckingShapeTracker((1,3,1,5,3,1))
|
||||
self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0)))
|
||||
|
||||
Reference in New Issue
Block a user