mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Fix the result permutation in einsum (#3895)
* Fix permutation of result indices in einsum. * Delete stray line used for breaking tests * Fix linter error by renaming twice-used variable --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
committed by
GitHub
parent
4e18dd78d3
commit
556dcfb8f2
@@ -576,6 +576,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(15,20), (20,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b))
|
||||
# matrix-matrix multiplication
|
||||
helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b))
|
||||
# matrix-matrix multiplication, different letter order
|
||||
helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('jk,ki->ji', a,b), lambda a,b: Tensor.einsum('jk,ki->ji', a, b))
|
||||
# dot product
|
||||
helper_test_op([(30),(30)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b]))
|
||||
# hadamard product
|
||||
@@ -588,12 +590,17 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(10,20,25),(10,25,32)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->jil', [a, b]))
|
||||
# batch matrix multiplication, result & input permuted
|
||||
helper_test_op([(20,10,25),(10,25,32)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b]))
|
||||
# batch matrix multiplication, result with different letters
|
||||
helper_test_op([(10,20,30),(10,30,40)], lambda a,b: torch.einsum('ijk,ika->ija', [a, b]), lambda a,b: Tensor.einsum('ijk,ika->ija', [a, b]))
|
||||
# tensor contraction
|
||||
helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b),
|
||||
lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5)
|
||||
# tensor contraction, input permuted
|
||||
helper_test_op([(3,8,10,5),(11,5,13,16,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b),
|
||||
lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5)
|
||||
# tensor contraction, result with different letters
|
||||
helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('zqrs,tuqvr->zstuv', a,b),
|
||||
lambda a,b: Tensor.einsum('zqrs,tuqvr->zstuv', a,b), atol=1e-5)
|
||||
# bilinear transformation
|
||||
helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user