mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -122,6 +122,7 @@ class TestRealSimplifies(unittest.TestCase):
|
||||
View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
|
||||
|
||||
class TestIndexExpressions2d(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
|
||||
offsets = [0, 1, 15, 28, 10000]
|
||||
@@ -187,7 +188,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
||||
st.expand((base_shape[0], base_shape[1], base_shape[1]))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
|
||||
|
||||
def test_permute_reshape_1(self): # This tests multiple views
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.permute((1, 0))
|
||||
@@ -417,94 +417,6 @@ class TestMaskedShapeTracker(unittest.TestCase):
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
self.st.assert_same()
|
||||
|
||||
def test_reshaping_splitting(self):
|
||||
self.st = CheckingShapeTracker((5,10,5,10))
|
||||
self.st.permute((1, 0, 3, 2))
|
||||
self.st.pad(((0,0), (0,5), (0,0), (0,5)))
|
||||
self.st.reshape((10,2,5,10,2,5))
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
def test_reshape_combining_1(self):
|
||||
self.st = CheckingShapeTracker((2,1,10))
|
||||
self.st.pad(((2,6), (0,0), (0,0)))
|
||||
self.st.reshape((100,))
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.skip("Can't make this optimization yet")
|
||||
def test_reshape_combining_2(self):
|
||||
self.st = CheckingShapeTracker((1,1,5))
|
||||
self.st.pad(((3,6), (0,0), (0,5)))
|
||||
self.st.reshape((100,))
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
@unittest.skip("Can't make this optimization yet")
|
||||
def test_reshape_splitting_combining(self):
|
||||
self.st = CheckingShapeTracker((1,5,5))
|
||||
self.st.pad(((0,4), (0,5), (0,0)))
|
||||
self.st.reshape((10,25))
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
def test_reshape_only_1s(self):
|
||||
self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1))
|
||||
self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0)))
|
||||
self.st.reshape((5, 6, 3, 5))
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1))
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1))
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
def test_zero_mask_1(self):
|
||||
self.st = CheckingShapeTracker((1, 3, 2))
|
||||
self.st.pad(((0,0), (0,3), (0,0)))
|
||||
self.st.shrink(((0,1), (3,6), (0,2)))
|
||||
self.st.reshape((3,2))
|
||||
self.st.assert_same()
|
||||
self.st.reshape((1, 3, 1, 2, 1))
|
||||
self.st.assert_same()
|
||||
|
||||
def test_zero_mask_2(self):
|
||||
self.st = CheckingShapeTracker((1, 3, 2))
|
||||
self.st.pad(((0,2), (0,3), (0,0)))
|
||||
self.st.shrink(((2,3), (3,6), (0,2)))
|
||||
self.st.reshape((3,2))
|
||||
self.st.assert_same()
|
||||
self.st.reshape((1, 3, 1, 2, 1))
|
||||
self.st.assert_same()
|
||||
|
||||
def test_expanded_reshaped(self):
|
||||
self.st = CheckingShapeTracker((1, 3, 2, 1))
|
||||
self.st.expand((5, 3, 2, 2))
|
||||
self.st.pad(((0,0), (0,3), (0,0), (0, 0)))
|
||||
self.st.reshape((5, 2, 3, 2, 2))
|
||||
assert len(self.st.views) == 1
|
||||
self.st.assert_same()
|
||||
|
||||
def test_splitting_big(self):
|
||||
self.st = CheckingShapeTracker((1, 5, 1, 15, 1))
|
||||
self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0)))
|
||||
self.st.reshape((10, 1, 30))
|
||||
self.st.permute((2,1,0))
|
||||
self.st.reshape((2,3,5,2,5))
|
||||
assert len(self.st.views) == 1
|
||||
v = self.st.views[-1]
|
||||
assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
|
||||
|
||||
def test_combining_big(self):
|
||||
self.st = CheckingShapeTracker((1,3,1,5,3,1))
|
||||
self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0)))
|
||||
self.st.reshape((1,1,1,105,1,1))
|
||||
assert len(self.st.views) == 1
|
||||
v = self.st.views[-1]
|
||||
assert v.strides == (0, 0, 0, 1, 0, 0) and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)), v.offset == -30
|
||||
|
||||
class TestShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.st = CheckingShapeTracker((7,4))
|
||||
|
||||
Reference in New Issue
Block a user