diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py
index 4e369940fe..0916504bca 100644
--- a/test/external/external_test_onnx_backend.py
+++ b/test/external/external_test_onnx_backend.py
@@ -151,8 +151,6 @@ backend_test.exclude('test_hannwindow_*')
 backend_test.exclude('test_hardmax_*')
 backend_test.exclude('test_gridsample_*')
 backend_test.exclude('test_dft_*')
-backend_test.exclude('test_einsum_batch_diagonal_cpu*') # TODO: equation = '...ii ->...i'
-backend_test.exclude('test_einsum_inner_prod_cpu*') # TODO: equation = 'i,i'
 backend_test.exclude('test_unique_*')
 backend_test.exclude('test_sequence_*')
 backend_test.exclude('test_nonmaxsuppression_*')
@@ -175,7 +173,6 @@ backend_test.exclude('test_tensorscatter_*')
 backend_test.exclude('test_l1normalization_*')
 backend_test.exclude('test_l2normalization_*')
 backend_test.exclude('test_lpnormalization_*')
-backend_test.exclude('test_einsum_scalar_cpu')
 backend_test.exclude('test_mod_mixed_sign_float16_cpu')
 backend_test.exclude('test_qlinearmatmul_2D_uint8_float16_cpu')
 backend_test.exclude('test_qlinearmatmul_3D_uint8_float16_cpu')
diff --git a/test/test_ops.py b/test/test_ops.py
index f635074b54..89b71912db 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1171,6 +1171,8 @@ class TestOps(unittest.TestCase):
 
   @slow_test
   def test_einsum(self):
+    # scalar
+    helper_test_op([()], lambda a: torch.einsum('->', a), lambda a: Tensor.einsum('->', a))
     # matrix transpose
     helper_test_op([(10,10)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
     helper_test_op([(10,10)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
@@ -1239,6 +1241,18 @@ class TestOps(unittest.TestCase):
     self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
                                lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError)
 
+  def test_einsum_trace(self):
+    # inner product
+    helper_test_op([(5,), (5,)], lambda a, b: torch.einsum('i,i', a, b), lambda a, b: Tensor.einsum('i,i', a, b))
+    # simple diagonal
+    helper_test_op([(4, 4)], lambda a: torch.einsum('ii->i', a), lambda a: Tensor.einsum('ii->i', a))
+    # trace (sum of diagonal)
+    helper_test_op([(4, 4)], lambda a: torch.einsum('ii->', a), lambda a: Tensor.einsum('ii->', a))
+    # batch diagonal
+    helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...i', a), lambda a: Tensor.einsum('...ii->...i', a))
+    # batch trace
+    helper_test_op([(3, 5, 5)], lambda a: torch.einsum('...ii->...', a), lambda a: Tensor.einsum('...ii->...', a))
+
   def test_einsum_shape_check(self):
     self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]),
                                lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index f00b598d6a..7f40504188 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2078,14 +2078,24 @@ class Tensor(OpMixin):
     inputs = inputs_str.split(",")
     if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}")
 
-    # map the value of each letter in the formula
-    letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
+    # handle trace (a letter repeated within a single input means take the diagonal of those two dimensions)
+    xs_:list[Tensor] = list(xs)
+    for i, (letters, x) in enumerate(zip(inputs, xs)):
+      for c in set(letters):
+        while (idxs := [j for j, ch in enumerate(letters) if ch == c]) and len(idxs) > 1:
+          d0, d1, n = idxs[0], idxs[1], cast(int, x.shape[idxs[0]])  # first two occurrences of c and their dim size
+          perm = [j for j in range(x.ndim) if j not in (d0, d1)] + [d0, d1]  # move the repeated pair of dims to the end
+          x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1, (n, n+1))[..., 0] if x.ndim > 2 else x.diagonal()
+          letters = ''.join(ch for j, ch in enumerate(letters) if j not in (d0, d1)) + c  # the extracted diagonal is now the last dim
+      inputs[i], xs_[i] = letters, x
+
+    # map the value of each letter in the formula
+    letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs_)]).items())
 
-    xs_:list[Tensor] = []
     lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
-    for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]):
-      # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
-      xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
+    # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
+    xs_ = [x.permute(o).reshape([v if l in letters else 1 for l,v in letter_val]).expand([v for _,v in letter_val])
+           for x,(o,letters) in zip(xs_, [list(zip(*l)) if l else ((), ()) for l in lhs])]
 
     # ordinal encode the output alphabet
     rhs_order = argsort(argsort(list(output)))
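
A standalone sketch of the diagonal trick used in the tinygrad/tensor.py hunk above. Plain NumPy stands in for tinygrad here, and the helper name batch_diagonal is illustrative rather than part of either library: for an (..., n, n) array, flattening the last two dims to n*n, padding n zeros, and reshaping to (n, n+1) lands every diagonal entry (flat offset k*(n+1)) in column 0.

import numpy as np

def batch_diagonal(x: np.ndarray) -> np.ndarray:
  # flatten the trailing n*n block, pad n zeros, reshape to (n, n+1);
  # diagonal entry k sits at flat offset k*n + k = k*(n+1), i.e. column 0
  *batch, n, m = x.shape
  assert n == m, "last two dims must match"
  flat = x.reshape(*batch, n * n)
  padded = np.pad(flat, [(0, 0)] * len(batch) + [(0, n)])  # (..., n*n + n)
  return padded.reshape(*batch, n, n + 1)[..., 0]          # (..., n)

a = np.arange(3 * 5 * 5, dtype=np.float32).reshape(3, 5, 5)
assert np.allclose(batch_diagonal(a), np.einsum('...ii->...i', a))

This is the same arithmetic the permute/flatten/pad/unflatten chain in the patch performs on tinygrad Tensors; the patch falls back to x.diagonal() only in the plain 2-D case.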
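
As a quick usage check, each equation enabled by this patch can be compared against NumPy (a minimal sketch, assuming the patch above is applied to tinygrad):

import numpy as np
from tinygrad import Tensor

a = np.random.rand(4, 4).astype(np.float32)
b = np.random.rand(3, 5, 5).astype(np.float32)
v = np.random.rand(5).astype(np.float32)

# diagonal, trace, batch diagonal, batch trace, inner product
for formula, args in [('ii->i', (a,)), ('ii->', (a,)), ('...ii->...i', (b,)),
                      ('...ii->...', (b,)), ('i,i', (v, v))]:
  out = Tensor.einsum(formula, *[Tensor(arg) for arg in args]).numpy()
  assert np.allclose(out, np.einsum(formula, *args), atol=1e-5), formula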