mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
support einsum trace (#14012)
* support einsum trace
* exclude test_einsum_scalar_cpu
This commit is contained in:
3
test/external/external_test_onnx_backend.py
vendored
3
test/external/external_test_onnx_backend.py
vendored
@@ -151,8 +151,6 @@ backend_test.exclude('test_hannwindow_*')
|
||||
backend_test.exclude('test_hardmax_*')
|
||||
backend_test.exclude('test_gridsample_*')
|
||||
backend_test.exclude('test_dft_*')
|
||||
backend_test.exclude('test_einsum_batch_diagonal_cpu*') # TODO: equation = '...ii ->...i'
|
||||
backend_test.exclude('test_einsum_inner_prod_cpu*') # TODO: equation = 'i,i'
|
||||
backend_test.exclude('test_unique_*')
|
||||
backend_test.exclude('test_sequence_*')
|
||||
backend_test.exclude('test_nonmaxsuppression_*')
|
||||
@@ -175,7 +173,6 @@ backend_test.exclude('test_tensorscatter_*')
|
||||
backend_test.exclude('test_l1normalization_*')
|
||||
backend_test.exclude('test_l2normalization_*')
|
||||
backend_test.exclude('test_lpnormalization_*')
|
||||
backend_test.exclude('test_einsum_scalar_cpu')
|
||||
backend_test.exclude('test_mod_mixed_sign_float16_cpu')
|
||||
backend_test.exclude('test_qlinearmatmul_2D_uint8_float16_cpu')
|
||||
backend_test.exclude('test_qlinearmatmul_3D_uint8_float16_cpu')
|
||||
|
||||
@@ -1171,6 +1171,8 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
@slow_test
|
||||
def test_einsum(self):
|
||||
# scalar
|
||||
helper_test_op([()], lambda a: torch.einsum('->', a), lambda a: Tensor.einsum('->', a))
|
||||
# matrix transpose
|
||||
helper_test_op([(10,10)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
|
||||
helper_test_op([(10,10)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
|
||||
@@ -1239,6 +1241,18 @@ class TestOps(unittest.TestCase):
|
||||
self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
|
||||
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError)
|
||||
|
||||
def test_einsum_trace(self):
|
||||
# inner product
|
||||
helper_test_op([(5,), (5,)], lambda a, b: torch.einsum('i,i', a, b), lambda a, b: Tensor.einsum('i,i', a, b))
|
||||
# simple diagonal
|
||||
helper_test_op([(4, 4)], lambda a: torch.einsum('ii->i', a), lambda a: Tensor.einsum('ii->i', a))
|
||||
# trace (sum of diagonal)
|
||||
helper_test_op([(4, 4)], lambda a: torch.einsum('ii->', a), lambda a: Tensor.einsum('ii->', a))
|
||||
# batch diagonal
|
||||
helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...i', a), lambda a: Tensor.einsum('...ii->...i', a))
|
||||
# batch trace
|
||||
helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...', a), lambda a: Tensor.einsum('...ii->...', a))
|
||||
|
||||
def test_einsum_shape_check(self):
|
||||
self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]),
|
||||
lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError)
|
||||
|
||||
Reference in New Issue
Block a user