mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 15:15:13 -05:00
rhs_order in einsum is argsort twice (#3990)
* rhs_order in einsum is argsort twice * comment
This commit is contained in:
@@ -7,7 +7,7 @@ from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype
|
||||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv
|
||||
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv, argsort
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.features.multi import MultiLazyBuffer
|
||||
@@ -670,10 +670,9 @@ class Tensor:
|
||||
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
||||
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
|
||||
|
||||
# Determine the inverse permutation to revert back to original order
|
||||
rhs_letter_order = [idx for idx,_ in sorted(enumerate(output), key=lambda e:e[1])]
|
||||
rhs_order:List[int] = [0]*len(rhs_letter_order)
|
||||
for sorted_idx,orig_idx in enumerate(rhs_letter_order): rhs_order[orig_idx] = sorted_idx
|
||||
# determine the inverse permutation to revert back to original order
|
||||
rhs_letter_order = argsort(list(output))
|
||||
rhs_order = argsort(rhs_letter_order)
|
||||
|
||||
# sum over all axes that's not in the output, then permute to the output order
|
||||
return functools.reduce(lambda a,b:a*b, xs_) \
|
||||
|
||||
Reference in New Issue
Block a user