mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
simplify get_contraction (#1373)
This commit is contained in:
4
test/external/external_test_opt.py
vendored
4
test/external/external_test_opt.py
vendored
@@ -294,7 +294,7 @@ class TestOpt(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
|
||||
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
|
||||
|
||||
def test_permute_was_pushed_though_contract_reshape(self):
|
||||
def test_permute_was_pushed_through_contract_reshape(self):
|
||||
a = Tensor.randn(4, 4, 4, 4, 4)
|
||||
with CLCache():
|
||||
c = a.sum(-1)
|
||||
@@ -304,7 +304,7 @@ class TestOpt(unittest.TestCase):
|
||||
np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
|
||||
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
|
||||
|
||||
def test_permute_was_pushed_though_contractw1s_reshape(self):
|
||||
def test_permute_was_pushed_through_contractw1s_reshape(self):
|
||||
a = Tensor.randn(4, 4, 4, 4, 4)
|
||||
with CLCache():
|
||||
c = a.sum(-1)
|
||||
|
||||
@@ -571,5 +571,24 @@ class TestGetContraction(unittest.TestCase):
|
||||
r = get_contraction((1,2,3,4), (1,2,6,2))
|
||||
self.assertEqual(r, None)
|
||||
|
||||
def test_contraction_ones(self):
|
||||
r = get_contraction((1,), (1,1,1))
|
||||
self.assertEqual(r, [[0], [], []])
|
||||
|
||||
r = get_contraction((1,1), (1,1,1))
|
||||
self.assertEqual(r, [[0], [1], []])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,))
|
||||
self.assertEqual(r, [[0,1,2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1))
|
||||
self.assertEqual(r, [[0], [1,2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1,1))
|
||||
self.assertEqual(r, [[0], [1], [2,3]])
|
||||
|
||||
r = get_contraction((1,1,1,1), (1,1,1,1))
|
||||
self.assertEqual(r, [[0], [1], [2], [3]])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -282,11 +282,11 @@ def get_contraction(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Opt
|
||||
if new_shape[i] == 1 and old_shape[old_shape_i] != 1:
|
||||
if i < len(new_shape) - 1: i += 1
|
||||
else:
|
||||
if new_shape[i] % old_shape[old_shape_i] != 0 or prod([old_shape[x] for x in axis_groups[i]]) * old_shape[old_shape_i] > new_shape[i]:
|
||||
return None
|
||||
axis_groups[i].append(old_shape_i)
|
||||
axis_group_size = prod([old_shape[x] for x in axis_groups[i]])
|
||||
# Move to next axes group if total size of all dimensions match.
|
||||
if prod([old_shape[x] for x in axis_groups[i]]) == new_shape[i]:
|
||||
if axis_group_size == new_shape[i]:
|
||||
if i < len(new_shape) - 1: i += 1
|
||||
elif axis_group_size > new_shape[i]: return None
|
||||
old_shape_i += 1
|
||||
return axis_groups
|
||||
|
||||
Reference in New Issue
Block a user