simplify get_contraction (#1373)

This commit is contained in:
chenyu
2023-07-30 18:35:22 -07:00
committed by GitHub
parent a32c677601
commit 1fdf560fb1
3 changed files with 24 additions and 5 deletions

View File

@@ -294,7 +294,7 @@ class TestOpt(unittest.TestCase):
np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
def test_permute_was_pushed_though_contract_reshape(self):
def test_permute_was_pushed_through_contract_reshape(self):
a = Tensor.randn(4, 4, 4, 4, 4)
with CLCache():
c = a.sum(-1)
@@ -304,7 +304,7 @@ class TestOpt(unittest.TestCase):
np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5)
if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!"
def test_permute_was_pushed_though_contractw1s_reshape(self):
def test_permute_was_pushed_through_contractw1s_reshape(self):
a = Tensor.randn(4, 4, 4, 4, 4)
with CLCache():
c = a.sum(-1)