From e3081355feaa18e3fcad8369e78b4374a744d741 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 17 Nov 2024 16:11:30 -0500 Subject: [PATCH] minor Tensor.einsum cleanup (#7752) removed some dead conditions and add types. still reads more complicated than needed --- tinygrad/tensor.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2f8a155f76..32e808a58c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1936,7 +1936,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] Tensor: + def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor: """ Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention. @@ -1948,19 +1948,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(Tensor.einsum("ij,ij->", x, y).numpy()) ``` """ - def parse_formula(formula: str, *operands: Tensor): - if "." in formula: + def parse_formula(formula:str, *operands:Tensor): + if "..." in (formula := formula.replace(" ", "")): ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0 for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))): - if (ell_count := max(operands[i].ndim, 1) - (len(inp) - 3)) > ell_longest: ell_longest = ell_count - inputs[i] = inp.replace("...", "" if ell_count == 0 else ell_chars[-ell_count:]) - inputs_str, out_ellipse = ",".join(inputs), "" if ell_longest == 0 else ell_chars[-ell_longest:] - return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else (inputs_str, \ - out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse))) + if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count + inputs[i] = inp.replace("...", ell_chars[-ell_count:]) + inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:] + return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \ + (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse))) return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha())) - xs:Tuple[Tensor, ...] = argfix(*raw_xs) - inputs_str, output = parse_formula(formula.replace(" ", ""), *xs) + xs:Tuple[Tensor, ...] = argfix(*operands) + inputs_str, output = parse_formula(formula, *xs) inputs = inputs_str.split(",") assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}" @@ -1973,13 +1973,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val])) - # determine the inverse permutation to revert back to original order - rhs_letter_order = argsort(list(output)) - rhs_order = argsort(rhs_letter_order) + # ordinal encode the output alphabet + rhs_order = argsort(argsort(list(output))) # sum over all axes that's not in the output, then permute to the output order return functools.reduce(lambda a,b:a*b, xs_) \ - .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order) + .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order) # ***** processing ops *****