rhs_order in einsum is argsort twice (#3990)

* rhs_order in einsum is argsort twice

* comment
This commit is contained in:
chenyu
2024-03-29 11:42:04 -04:00
committed by GitHub
parent 7bc560ec49
commit 4abb8245a6

View File

@@ -7,7 +7,7 @@ from collections import defaultdict
import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv, argsort
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
@@ -670,10 +670,9 @@ class Tensor:
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
# Determine the inverse permutation to revert back to original order
rhs_letter_order = [idx for idx,_ in sorted(enumerate(output), key=lambda e:e[1])]
rhs_order:List[int] = [0]*len(rhs_letter_order)
for sorted_idx,orig_idx in enumerate(rhs_letter_order): rhs_order[orig_idx] = sorted_idx
# determine the inverse permutation to revert back to original order
rhs_letter_order = argsort(list(output))
rhs_order = argsort(rhs_letter_order)
# sum over all axes that's not in the output, then permute to the output order
return functools.reduce(lambda a,b:a*b, xs_) \