Add einsum tests (#6286)

Co-authored-by: Maximilian Weichart <maximilian.weichart@icloud.com>
This commit is contained in:
Max-We
2024-08-26 18:09:25 +02:00
committed by GitHub
parent b76f0c875e
commit ab2714423b

View File

@@ -752,10 +752,28 @@ class TestOps(unittest.TestCase):
lambda a,b: Tensor.einsum('zqrs,tuqvr->zstuv', a,b), atol=1e-5)
# bilinear transformation
helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c]))
# test ellipsis # TODO: FIXME
with self.assertRaises(Exception):
helper_test_op([(16,29,256),(16,29,256)], lambda a,b: torch.einsum('...id, ...jd -> ...ij', [a,b]),
lambda a,b: Tensor.einsum('...id, ...jd -> ...ij', [a,b]))
@unittest.expectedFailure
def test_einsum_ellipsis(self):
"""The expected behavior for einsum is described in the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.einsum.html"""
# TODO: implement ellipsis support in einsum to pass these tests
# test ellipsis
helper_test_op([(3, 8, 9), (3, 8, 9)], lambda a, b: torch.einsum('...id, ...jd -> ...ij', [a, b]),
lambda a, b: Tensor.einsum('...id, ...jd -> ...ij', [a, b]))
# ellipsis will come first in the output before the subscript labels, if rhs is not specified
helper_test_op([(3, 8, 9), (3, 8, 9)], lambda a, b: torch.einsum('...id, ...jd', [a, b]),
lambda a, b: Tensor.einsum('...id, ...jd', [a, b]))
# multiple ellipsis in different operands with different shapes are allowed
helper_test_op([(2, 3, 4, 5), (5, 2, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]))
# multiple ellipsis in one operand are not allowed. This test shall raise an exception.
with self.assertRaises(RuntimeError):
helper_test_op([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),
lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b]))
# multiple ellipsis must broadcast together. This test shall raise an exception.
with self.assertRaises(RuntimeError):
helper_test_op([(2, 3, 4, 5), (5, 2, 7)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]))
def test_einsum_shape_check(self):
a = Tensor.zeros(3,8,10,5)