diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7f40504188..ad26051b8e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2062,47 +2062,33 @@ class Tensor(OpMixin):
     print(Tensor.einsum("ij,ij->", x, y).numpy())
     ```
     """
-    def parse_formula(formula:str, *operands:Tensor):
-      if "..." in (formula := formula.replace(" ", "")):
-        ell_chars, ell_longest = "".join(c for c in string.ascii_letters if c not in formula), 0
-        for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
-          if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
-          inputs[i] = inp.replace("...", ell_chars[-ell_count:])
-        inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:]
-        return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \
-          (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
-      return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
-
-    xs:tuple[Tensor, ...] = argfix(*operands)
-    inputs_str, output = parse_formula(formula, *xs)
-    inputs = inputs_str.split(",")
-    if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}")
-
-    # handle trace (repeated letter in single input means take diagonal)
-    xs_:list[Tensor] = list(xs)
-    for i, (letters, x) in enumerate(zip(inputs, xs)):
-      for c in set(letters):
-        while (idxs := [j for j, ch in enumerate(letters) if ch == c]) and len(idxs) > 1:
-          d0, d1, n = idxs[0], idxs[1], cast(int, x.shape[idxs[0]])
-          perm = [j for j in range(x.ndim) if j not in (d0, d1)] + [d0, d1]
-          x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1, (n, n+1))[..., 0] if x.ndim > 2 else x.diagonal()
-          letters = letters[:d1] + letters[d1+1:]
-      inputs[i], xs_[i] = letters, x
-
-    # map the value of each letter in the formula
-    letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs_)]).items())
-
-    lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
-    # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
-    xs_ = [x.permute(o).reshape([v if l in letters else 1 for l,v in letter_val]).expand([v for _,v in letter_val])
-      for x,(o,letters) in zip(xs_, [list(zip(*l)) if l else ((), ()) for l in lhs])]
-
-    # ordinal encode the output alphabet
-    rhs_order = argsort(argsort(list(output)))
-
-    # sum over all axes that's not in the output, then permute to the output order
-    return functools.reduce(lambda a,b:a*b, xs_) \
-      .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)
+    xs, formula = list(argfix(*operands)), formula.replace(" ", "")
+    # expand ellipsis to letters, determine output
+    if "..." in formula:
+      ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
+      ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
+      for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
+      lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
+      formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
+    lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
+    inputs = lhs.split(",")
+    if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
+    # trace: take diagonal when letter repeats in single input
+    for i, (s, x) in enumerate(zip(inputs, xs)):
+      for c in set(s):
+        while s.count(c) > 1:
+          j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
+          perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
+          x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
+          s = s[:k] + s[k+1:]
+      inputs[i], xs[i] = s, x
+    # check sizes and build sorted alphabet
+    sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
+    alpha = sorted(sz)
+    # align all tensors to alphabet, multiply, sum non-output, permute to output order
+    xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
+      for s, x in zip(inputs, xs)]
+    return functools.reduce(lambda a,b:a*b, xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
 
   # ***** processing ops *****
 
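For anyone who wants to exercise the rewritten paths locally, here is a small spot-check against `numpy.einsum`. It is a sketch, not part of the patch: the formulas, shapes, and tolerances below are illustrative, and it assumes `numpy` plus this branch of tinygrad are installed. The cases cover ellipsis expansion, implicit (no `->`) output, and the repeated-letter diagonal handled in the hunk above.

```python
# Spot-check sketch (not part of the patch): compare a few formulas against numpy.einsum.
import numpy as np
from tinygrad import Tensor

a = np.random.rand(2, 3, 4).astype(np.float32)
b = np.random.rand(4, 5).astype(np.float32)
c = np.random.rand(3, 3).astype(np.float32)

cases = [
  ("ijk,kl->ijl", (a, b)),  # explicit output
  ("...k,kl", (a, b)),      # ellipsis expansion + implicit output
  ("ii->", (c,)),           # repeated letter in one input -> diagonal, then summed
  ("ij,ij->", (c, c)),      # elementwise product reduced to a scalar
]
for formula, args in cases:
  out = Tensor.einsum(formula, *[Tensor(x) for x in args]).numpy()
  np.testing.assert_allclose(out, np.einsum(formula, *args), rtol=1e-5, atol=1e-5)
  print(f"{formula}: OK, output shape {out.shape}")
```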