simplify get_contraction (#1373)

This commit is contained in:
chenyu
2023-07-30 18:35:22 -07:00
committed by GitHub
parent a32c677601
commit 1fdf560fb1
3 changed files with 24 additions and 5 deletions

View File

@@ -294,7 +294,7 @@ class TestOpt(unittest.TestCase):
np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
def test_permute_was_pushed_though_contract_reshape(self):
def test_permute_was_pushed_through_contract_reshape(self):
a = Tensor.randn(4, 4, 4, 4, 4)
with CLCache():
c = a.sum(-1)
@@ -304,7 +304,7 @@ class TestOpt(unittest.TestCase):
np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
def test_permute_was_pushed_though_contractw1s_reshape(self):
def test_permute_was_pushed_through_contractw1s_reshape(self):
a = Tensor.randn(4, 4, 4, 4, 4)
with CLCache():
c = a.sum(-1)

View File

@@ -571,5 +571,24 @@ class TestGetContraction(unittest.TestCase):
r = get_contraction((1,2,3,4), (1,2,6,2))
self.assertEqual(r, None)
def test_contraction_ones(self):
  # Exercise get_contraction on shapes built only from size-1 dimensions.
  # Each entry is (old_shape, new_shape, expected axis groups); entries are
  # checked in the order listed, stopping at the first mismatch.
  cases = (
    ((1,), (1,1,1), [[0], [], []]),
    ((1,1), (1,1,1), [[0], [1], []]),
    ((1,1,1,1), (1,), [[0,1,2,3]]),
    ((1,1,1,1), (1,1), [[0], [1,2,3]]),
    ((1,1,1,1), (1,1,1), [[0], [1], [2,3]]),
    ((1,1,1,1), (1,1,1,1), [[0], [1], [2], [3]]),
  )
  for old_shape, new_shape, expected in cases:
    self.assertEqual(get_contraction(old_shape, new_shape), expected)
if __name__ == '__main__':
unittest.main()

View File

@@ -282,11 +282,11 @@ def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Opt
if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
if i < len(new_shape) - 1: i += 1
else:
if new_shape[i] % old_shape[old_shape_i] != 0 or prod([old_shape[x] for x in axis_groups[i]]) * old_shape[old_shape_i] > new_shape[i]:
return None
axis_groups[i].append(old_shape_i)
axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
# Move to next axes group if total size of all dimensions match.
if prod([old_shape[x] for x in axis_groups[i]]) == new_shape[i]:
if axis_group_size == new_shape[i]:
if i < len(new_shape) - 1: i += 1
elif axis_group_size > new_shape[i]: return None
old_shape_i += 1
return axis_groups