Fix the result permutation in einsum (#3895)

* Fix permutation of result indices in einsum.

* Delete stray line used for breaking tests

* Fix linter error by renaming twice-used variable

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
Alejandro F Queiruga
2024-03-23 15:48:19 -04:00
committed by GitHub
parent 4e18dd78d3
commit 556dcfb8f2
2 changed files with 12 additions and 1 deletions

View File

@@ -576,6 +576,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(15,20), (20,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b))
# matrix-matrix multiplication
helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b))
# matrix-matrix multiplication, different letter order
helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('jk,ki->ji', a,b), lambda a,b: Tensor.einsum('jk,ki->ji', a, b))
# dot product
helper_test_op([(30),(30)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b]))
# hadamard product
@@ -588,12 +590,17 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,20,25),(10,25,32)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->jil', [a, b]))
# batch matrix multiplication, result & input permuted
helper_test_op([(20,10,25),(10,25,32)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b]))
# batch matrix multiplication, result with different letters
helper_test_op([(10,20,30),(10,30,40)], lambda a,b: torch.einsum('ijk,ika->ija', [a, b]), lambda a,b: Tensor.einsum('ijk,ika->ija', [a, b]))
# tensor contraction
helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b),
lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5)
# tensor contraction, input permuted
helper_test_op([(3,8,10,5),(11,5,13,16,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b),
lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5)
# tensor contraction, result with different letters
helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('zqrs,tuqvr->zstuv', a,b),
lambda a,b: Tensor.einsum('zqrs,tuqvr->zstuv', a,b), atol=1e-5)
# bilinear transformation
helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c]))

View File

@@ -660,7 +660,11 @@ class Tensor:
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
# Determine the inverse permutation to revert back to original order
rhs_order_sorted, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
rhs_order:List[int] = [0]*len(rhs_order_sorted)
for sorted_idx,orig_idx in enumerate(rhs_order_sorted): rhs_order[orig_idx] = sorted_idx
# sum over all axes that's not in the output, then permute to the output order
return functools.reduce(lambda a,b:a*b, xs_) \
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)