mirror of https://github.com/tinygrad/tinygrad.git
Fix the result permutation in einsum (#3895)
* Fix permutation of result indices in einsum.
* Delete stray line used for breaking tests
* Fix linter error by renaming twice-used variable

Co-authored-by: chenyu <chenyu@fastmail.com>
parent 4e18dd78d3
commit 556dcfb8f2
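For context: the permutation bug shows up when the requested output subscripts are not already in alphabetical order and the mapping back from sorted order is more than a simple swap (e.g. an output like 'ija'). A rough reproduction sketch, not part of the commit and assuming numpy plus a tinygrad build that includes this fix, could look like:

# Illustrative reproduction sketch, not from the commit: checks tinygrad's einsum
# against numpy for an output order ('ija') that exercises the result permutation.
import numpy as np
from tinygrad.tensor import Tensor

a = np.random.rand(10, 20, 30).astype(np.float32)   # axes i, j, k
b = np.random.rand(10, 30, 40).astype(np.float32)   # axes i, k, a
ref = np.einsum('ijk,ika->ija', a, b)                # reference result, shape (10, 20, 40)
out = Tensor.einsum('ijk,ika->ija', Tensor(a), Tensor(b)).numpy()
np.testing.assert_allclose(out, ref, rtol=1e-4, atol=1e-4)  # passes with the fix applied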
test/test_ops.py
@@ -576,6 +576,8 @@ class TestOps(unittest.TestCase):
    helper_test_op([(15,20), (20,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b))
    # matrix-matrix multiplication
    helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b))
    # matrix-matrix multiplication, different letter order
    helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('jk,ki->ji', a,b), lambda a,b: Tensor.einsum('jk,ki->ji', a, b))
    # dot product
    helper_test_op([(30),(30)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b]))
    # hadamard product
@@ -588,12 +590,17 @@ class TestOps(unittest.TestCase):
    helper_test_op([(10,20,25),(10,25,32)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->jil', [a, b]))
    # batch matrix multiplication, result & input permuted
    helper_test_op([(20,10,25),(10,25,32)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b]))
    # batch matrix multiplication, result with different letters
    helper_test_op([(10,20,30),(10,30,40)], lambda a,b: torch.einsum('ijk,ika->ija', [a, b]), lambda a,b: Tensor.einsum('ijk,ika->ija', [a, b]))
    # tensor contraction
    helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b),
                   lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5)
    # tensor contraction, input permuted
    helper_test_op([(3,8,10,5),(11,5,13,16,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b),
                   lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5)
    # tensor contraction, result with different letters
    helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('zqrs,tuqvr->zstuv', a,b),
                   lambda a,b: Tensor.einsum('zqrs,tuqvr->zstuv', a,b), atol=1e-5)
    # bilinear transformation
    helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c]))
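A side note on why only some permuted-output cases exposed the bug (illustrative, not part of the commit): a two-element swap is its own inverse, so outputs like 'ji' or 'jil' happened to come out right even when the sorted-order permutation was used directly, while outputs such as 'ija' or 'zstuv' need the true inverse:

# Quick standalone check: for each output subscript string, compare the
# "sorted letter order" permutation with its inverse. When the two differ,
# using the non-inverted order permutes the result incorrectly.
def sorted_order(output):
  return [orig for orig, _ in sorted(enumerate(output), key=lambda e: e[1])]

def inverse(perm):
  inv = [0] * len(perm)
  for k, orig in enumerate(perm): inv[orig] = k
  return inv

for out in ['ji', 'jil', 'ija', 'zstuv']:
  order = sorted_order(out)
  print(out, order, inverse(order), order == inverse(order))
# ji    [1, 0]          [1, 0]          True   (a swap is its own inverse)
# jil   [1, 0, 2]       [1, 0, 2]       True
# ija   [2, 0, 1]       [1, 2, 0]       False  (needs the inverse)
# zstuv [1, 2, 3, 4, 0] [4, 0, 1, 2, 3] False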
tinygrad/tensor.py
@@ -660,7 +660,11 @@ class Tensor:
      # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
      xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))

-    rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
+    # Determine the inverse permutation to revert back to original order
+    rhs_order_sorted, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], [])
+    rhs_order:List[int] = [0]*len(rhs_order_sorted)
+    for sorted_idx,orig_idx in enumerate(rhs_order_sorted): rhs_order[orig_idx] = sorted_idx
+
    # sum over all axes that's not in the output, then permute to the output order
    return functools.reduce(lambda a,b:a*b, xs_) \
      .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order)
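To see why the inverse is the right direction here (permute and numpy's transpose both follow the convention "result axis i comes from source axis perm[i]"), a minimal standalone sketch of the patched logic, assuming numpy only and not taken from the tinygrad source:

# After the reduction, the surviving axes sit in sorted-letter order; the output
# order is whatever the formula requested, so we need the inverse permutation.
import numpy as np

def result_permutation(output):
  # sorted_to_orig[k] = position in `output` of the k-th letter in sorted order
  sorted_to_orig = [orig for orig, _ in sorted(enumerate(output), key=lambda e: e[1])]
  # invert it: perm[orig] = k, i.e. output axis `orig` is taken from sorted axis k
  perm = [0] * len(sorted_to_orig)
  for k, orig in enumerate(sorted_to_orig): perm[orig] = k
  return perm

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)      # axes a, b, c (already in sorted order)
want = np.einsum('abc->bca', x)                # requested output order: b, c, a
perm = result_permutation('bca')               # -> [1, 2, 0], the inverse permutation
assert (x.transpose(perm) == want).all()       # inverse permutation reproduces einsum
bad = [o for o, _ in sorted(enumerate('bca'), key=lambda e: e[1])]  # -> [2, 0, 1]
assert x.transpose(bad).shape != want.shape    # non-inverted order is wrong here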