mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
minor Tensor.einsum cleanup (#7752)
removed some dead conditions and add types. still reads more complicated than needed
This commit is contained in:
@@ -1936,7 +1936,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
|
||||
|
||||
@staticmethod
|
||||
def einsum(formula:str, *raw_xs, acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
||||
def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor:
|
||||
"""
|
||||
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
|
||||
|
||||
@@ -1948,19 +1948,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
||||
```
|
||||
"""
|
||||
def parse_formula(formula: str, *operands: Tensor):
|
||||
if "." in formula:
|
||||
def parse_formula(formula:str, *operands:Tensor):
|
||||
if "..." in (formula := formula.replace(" ", "")):
|
||||
ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0
|
||||
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
|
||||
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - 3)) > ell_longest: ell_longest = ell_count
|
||||
inputs[i] = inp.replace("...", "" if ell_count == 0 else ell_chars[-ell_count:])
|
||||
inputs_str, out_ellipse = ",".join(inputs), "" if ell_longest == 0 else ell_chars[-ell_longest:]
|
||||
return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else (inputs_str, \
|
||||
out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
|
||||
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
|
||||
inputs[i] = inp.replace("...", ell_chars[-ell_count:])
|
||||
inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:]
|
||||
return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \
|
||||
(inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
|
||||
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
||||
|
||||
xs:Tuple[Tensor, ...] = argfix(*raw_xs)
|
||||
inputs_str, output = parse_formula(formula.replace(" ", ""), *xs)
|
||||
xs:Tuple[Tensor, ...] = argfix(*operands)
|
||||
inputs_str, output = parse_formula(formula, *xs)
|
||||
inputs = inputs_str.split(",")
|
||||
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
||||
|
||||
@@ -1973,13 +1973,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
||||
xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val]))
|
||||
|
||||
# determine the inverse permutation to revert back to original order
|
||||
rhs_letter_order = argsort(list(output))
|
||||
rhs_order = argsort(rhs_letter_order)
|
||||
# ordinal encode the output alphabet
|
||||
rhs_order = argsort(argsort(list(output)))
|
||||
|
||||
# sum over all axes that's not in the output, then permute to the output order
|
||||
return functools.reduce(lambda a,b:a*b, xs_) \
|
||||
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order)
|
||||
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order)
|
||||
|
||||
# ***** processing ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user