mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move some ifs from merge_dims to reshape [pr] (#8229)
the third return value is only used in reshape
This commit is contained in:
@@ -50,7 +50,7 @@ class TestMergeDims(unittest.TestCase):
|
||||
shape = (2, 3, 4)
|
||||
self.assertEqual(merge_dims(shape, (0, 4, 1)), ((2, 0, 0), (12, 1, 12)))
|
||||
self.assertEqual(merge_dims(shape, (0, 0, 1)), ((6, 0, 0), (4, 1, 4)))
|
||||
self.assertEqual(merge_dims(shape, (3, 1, 0)), ((6, 1, 6), (4, 0, 0)))
|
||||
self.assertEqual(merge_dims(shape, (3, 1, 0)), ((6, 1, 6), (4, 0, 4)))
|
||||
self.assertEqual(merge_dims(shape, (0, 0, 0)), ((24, 0, 0),))
|
||||
|
||||
def test_pad_reshape(self):
|
||||
@@ -59,16 +59,28 @@ class TestMergeDims(unittest.TestCase):
|
||||
# shift mask on stride 0
|
||||
self.assertEqual(merge_dims((2, 3), (0, 1), ((0, 1), (0, 2))), ((6, 1, 3),))
|
||||
# permute 0 / 1
|
||||
self.assertEqual(merge_dims((3, 2), (1, 0), ((0, 2), (1, 2))), ((3, 1, 3), (2, 0, 0)))
|
||||
self.assertEqual(merge_dims((3, 2), (1, 0), ((0, 2), (1, 2))), ((3, 1, 3), (2, 0, 2)))
|
||||
|
||||
# st = ShapeTracker.from_shape((1, 1, 2)).pad(((1, 0), (1, 0), (0, 1)))
|
||||
# print(f"{st.views[-1]}")
|
||||
self.assertEqual(merge_dims((2, 2, 3), (0, 0, 1), ((1, 2), (1, 2), (0, 2))), ((12, 1, 3),))
|
||||
|
||||
# st = ShapeTracker.from_shape((1, 1, 2, 2)).pad(((1, 0), (1, 0), (0, 1), (0, 1)))
|
||||
# print(f"{st.views[-1]}")
|
||||
self.assertEqual(merge_dims((2, 2, 3, 3), (0, 0, 2, 1), ((1, 2), (1, 2), (0, 2), (0, 2))), ((12, 2, 3), (3, 1, 3)))
|
||||
|
||||
# st = ShapeTracker.from_shape((1, 1, 1, 1)).pad(((0, 2), (0, 0), (0, 1), (0, 3)))
|
||||
# print(f"{st.views[-1]}")
|
||||
# self.assertEqual(merge_dims((3, 1, 2, 4), (0, 0, 0, 0), ((0, 1), (0, 1), (0, 1), (0, 1))), ((24, 0, 0),))
|
||||
|
||||
def test_different_1_pad(self):
|
||||
# st = ShapeTracker.from_shape((2, 2, 1)).pad(((0, 0), (0, 0), (0, 1)))
|
||||
# print(f"{st.views[-1]}")
|
||||
self.assertEqual(merge_dims((2, 2, 2), (2, 1, 0), ((0, 2), (0, 2), (0, 1))), ((4, 1, 4), (2, 0, 0)))
|
||||
self.assertEqual(merge_dims((2, 2, 2), (2, 1, 0), ((0, 2), (0, 2), (0, 1))), ((4, 1, 4), (2, 0, 2)))
|
||||
|
||||
# st = ShapeTracker.from_shape((2, 1, 1)).pad(((0, 0), (0, 1), (0, 1)))
|
||||
# print(f"{st.views[-1]}")
|
||||
self.assertEqual(merge_dims((2, 2, 2), (1, 0, 0), ((0, 2), (0, 2), (0, 1))), ((2, 1, 2), (4, 0, 0)))
|
||||
self.assertEqual(merge_dims((2, 2, 2), (1, 0, 0), ((0, 2), (0, 2), (0, 1))), ((2, 1, 2), (4, 0, 4)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user