diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index 6f4ef2819a..dd5e14ff28 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -294,7 +294,7 @@ class TestOpt(unittest.TestCase): np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" - def test_permute_was_pushed_though_contract_reshape(self): + def test_permute_was_pushed_through_contract_reshape(self): a = Tensor.randn(4, 4, 4, 4, 4) with CLCache(): c = a.sum(-1) @@ -304,7 +304,7 @@ class TestOpt(unittest.TestCase): np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" - def test_permute_was_pushed_though_contractw1s_reshape(self): + def test_permute_was_pushed_through_contractw1s_reshape(self): a = Tensor.randn(4, 4, 4, 4, 4) with CLCache(): c = a.sum(-1) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index d823a882e2..56d25248ea 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -571,5 +571,24 @@ class TestGetContraction(unittest.TestCase): r = get_contraction((1,2,3,4), (1,2,6,2)) self.assertEqual(r, None) + def test_contraction_ones(self): + r = get_contraction((1,), (1,1,1)) + self.assertEqual(r, [[0], [], []]) + + r = get_contraction((1,1), (1,1,1)) + self.assertEqual(r, [[0], [1], []]) + + r = get_contraction((1,1,1,1), (1,)) + self.assertEqual(r, [[0,1,2,3]]) + + r = get_contraction((1,1,1,1), (1,1)) + self.assertEqual(r, [[0], [1,2,3]]) + + r = get_contraction((1,1,1,1), (1,1,1)) + self.assertEqual(r, [[0], [1], [2,3]]) + + r = get_contraction((1,1,1,1), (1,1,1,1)) + self.assertEqual(r, [[0], [1], [2], [3]]) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index de8cdfc226..c4fb84a3e5 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -282,11 +282,11 @@ def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Opt if new_shape[i] == 1 and old_shape[old_shape_i] != 1: if i < len(new_shape) - 1: i += 1 else: - if new_shape[i] % old_shape[old_shape_i] != 0 or prod([old_shape[x] for x in axis_groups[i]]) * old_shape[old_shape_i] > new_shape[i]: - return None axis_groups[i].append(old_shape_i) + axis_group_size = prod([old_shape[x] for x in axis_groups[i]]) # Move to next axes group if total size of all dimensions match. - if prod([old_shape[x] for x in axis_groups[i]]) == new_shape[i]: + if axis_group_size == new_shape[i]: if i < len(new_shape) - 1: i += 1 + elif axis_group_size > new_shape[i]: return None old_shape_i += 1 return axis_groups