diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e0844fec63..3444008821 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -7,7 +7,7 @@ from collections import defaultdict import numpy as np from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype -from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv +from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv, argsort from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY from tinygrad.lazy import LazyBuffer from tinygrad.features.multi import MultiLazyBuffer @@ -670,10 +670,9 @@ class Tensor: # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val])) - # Determine the inverse permutation to revert back to original order - rhs_letter_order = [idx for idx,_ in sorted(enumerate(output), key=lambda e:e[1])] - rhs_order:List[int] = [0]*len(rhs_letter_order) - for sorted_idx,orig_idx in enumerate(rhs_letter_order): rhs_order[orig_idx] = sorted_idx + # determine the inverse permutation to revert back to original order + rhs_letter_order = argsort(list(output)) + rhs_order = argsort(rhs_letter_order) # sum over all axes that's not in the output, then permute to the output order return functools.reduce(lambda a,b:a*b, xs_) \