support einsum trace (#14012)

* support einsum trace

* test_einsum_scalar_cpu
This commit is contained in:
chenyu
2026-01-04 19:27:27 -05:00
committed by GitHub
parent 404eed6172
commit f6a78a29e0
3 changed files with 30 additions and 9 deletions

View File

@@ -151,8 +151,6 @@ backend_test.exclude('test_hannwindow_*')
backend_test.exclude('test_hardmax_*')
backend_test.exclude('test_gridsample_*')
backend_test.exclude('test_dft_*')
backend_test.exclude('test_einsum_batch_diagonal_cpu*') # TODO: equation = '...ii ->...i'
backend_test.exclude('test_einsum_inner_prod_cpu*') # TODO: equation = 'i,i'
backend_test.exclude('test_unique_*')
backend_test.exclude('test_sequence_*')
backend_test.exclude('test_nonmaxsuppression_*')
@@ -175,7 +173,6 @@ backend_test.exclude('test_tensorscatter_*')
backend_test.exclude('test_l1normalization_*')
backend_test.exclude('test_l2normalization_*')
backend_test.exclude('test_lpnormalization_*')
backend_test.exclude('test_einsum_scalar_cpu')
backend_test.exclude('test_mod_mixed_sign_float16_cpu')
backend_test.exclude('test_qlinearmatmul_2D_uint8_float16_cpu')
backend_test.exclude('test_qlinearmatmul_3D_uint8_float16_cpu')

View File

@@ -1171,6 +1171,8 @@ class TestOps(unittest.TestCase):
@slow_test
def test_einsum(self):
# scalar
helper_test_op([()], lambda a: torch.einsum('->', a), lambda a: Tensor.einsum('->', a))
# matrix transpose
helper_test_op([(10,10)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
helper_test_op([(10,10)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
@@ -1239,6 +1241,18 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError)
def test_einsum_trace(self):
# inner product
helper_test_op([(5,), (5,)], lambda a, b: torch.einsum('i,i', a, b), lambda a, b: Tensor.einsum('i,i', a, b))
# simple diagonal
helper_test_op([(4, 4)], lambda a: torch.einsum('ii->i', a), lambda a: Tensor.einsum('ii->i', a))
# trace (sum of diagonal)
helper_test_op([(4, 4)], lambda a: torch.einsum('ii->', a), lambda a: Tensor.einsum('ii->', a))
# batch diagonal
helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...i', a), lambda a: Tensor.einsum('...ii->...i', a))
# batch trace
helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...', a), lambda a: Tensor.einsum('...ii->...', a))
def test_einsum_shape_check(self):
self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]),
lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError)