mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
simpler einsum (#14014)
This commit is contained in:
@@ -2062,47 +2062,33 @@ class Tensor(OpMixin):
|
||||
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
||||
```
|
||||
"""
|
||||
def parse_formula(formula:str, *operands:Tensor):
  """
  Split an einsum `formula` into its input subscripts and its output subscripts, expanding any `...`.

  Spaces are removed first. Each `...` is replaced by letters the formula does not use, one per operand dimension that the
  input does not name explicitly. Shorter ellipses reuse the trailing letters of the longest one, so they line up from the right.
  When the formula has no `->`, the output is the ellipsis letters followed by the sorted letters that occur exactly once.

  Returns a two-element sequence `(inputs, output)`: `inputs` is the comma separated input subscripts, `output` the output subscripts.
  """
  formula = formula.replace(" ", "")
  if "..." not in formula:
    if "->" in formula: return formula.split("->")
    return formula, "".join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha())

  # letters the formula does not use are free to stand in for the ellipsis dimensions
  ell_chars = "".join(c for c in string.ascii_letters if c not in formula)
  inputs, ell_longest = formula.split("->")[0].split(","), 0
  for i, inp in enumerate(inputs):
    # enumerate every input so operands[i] is the operand at position i, not the i-th input that happens to contain "..."
    # an input without a matching operand is left untouched, so a count mismatch is not turned into an IndexError here
    if "..." not in inp or i >= len(operands): continue
    # a 0-d operand is counted as rank 1, as before; clamp so too many explicit letters cannot produce a negative count
    ell_count = max(0, max(operands[i].ndim, 1) - (len(inp) - len("...")))
    ell_longest = max(ell_longest, ell_count)
    # slice from an explicit start index: ell_chars[-0:] would be the whole string rather than the empty one
    inputs[i] = inp.replace("...", ell_chars[len(ell_chars)-ell_count:])
  inputs_str, out_ellipse = ",".join(inputs), ell_chars[len(ell_chars)-ell_longest:]
  if "->" in formula: return inputs_str, formula.split("->")[1].replace("...", out_ellipse)
  # implicit output: ellipsis dimensions first, then every remaining letter that appears exactly once, in sorted order
  return inputs_str, out_ellipse + "".join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse))
|
||||
|
||||
xs:tuple[Tensor, ...] = argfix(*operands)
|
||||
inputs_str, output = parse_formula(formula, *xs)
|
||||
inputs = inputs_str.split(",")
|
||||
if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}")
|
||||
|
||||
# handle trace (repeated letter in single input means take diagonal)
|
||||
xs_:list[Tensor] = list(xs)
|
||||
for i, (letters, x) in enumerate(zip(inputs, xs)):
|
||||
for c in set(letters):
|
||||
while (idxs := [j for j, ch in enumerate(letters) if ch == c]) and len(idxs) > 1:
|
||||
d0, d1, n = idxs[0], idxs[1], cast(int, x.shape[idxs[0]])
|
||||
perm = [j for j in range(x.ndim) if j not in (d0, d1)] + [d0, d1]
|
||||
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1, (n, n+1))[..., 0] if x.ndim > 2 else x.diagonal()
|
||||
letters = letters[:d1] + letters[d1+1:]
|
||||
inputs[i], xs_[i] = letters, x
|
||||
|
||||
# map the value of each letter in the formula
|
||||
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs_)]).items())
|
||||
|
||||
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
|
||||
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
||||
xs_ = [x.permute(o).reshape([v if l in letters else 1 for l,v in letter_val]).expand([v for _,v in letter_val])
|
||||
for x,(o,letters) in zip(xs_, [list(zip(*l)) if l else ((), ()) for l in lhs])]
|
||||
|
||||
# ordinal encode the output alphabet
|
||||
rhs_order = argsort(argsort(list(output)))
|
||||
|
||||
# sum over all axes that's not in the output, then permute to the output order
|
||||
return functools.reduce(lambda a,b:a*b, xs_) \
|
||||
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)
|
||||
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
|
||||
# expand ellipsis to letters, determine output
|
||||
if "..." in formula:
|
||||
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
|
||||
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
|
||||
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
|
||||
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
|
||||
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
|
||||
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
|
||||
inputs = lhs.split(",")
|
||||
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
|
||||
# trace: take diagonal when letter repeats in single input
|
||||
for i, (s, x) in enumerate(zip(inputs, xs)):
|
||||
for c in set(s):
|
||||
while s.count(c) > 1:
|
||||
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
|
||||
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
|
||||
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
|
||||
s = s[:k] + s[k+1:]
|
||||
inputs[i], xs[i] = s, x
|
||||
# check sizes and build sorted alphabet
|
||||
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
|
||||
alpha = sorted(sz)
|
||||
# align all tensors to alphabet, multiply, sum non-output, permute to output order
|
||||
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
|
||||
for s, x in zip(inputs, xs)]
|
||||
return functools.reduce(lambda a,b:a*b, xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
|
||||
|
||||
# ***** processing ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user