mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
simpler einsum (#14014)
This commit is contained in:
@@ -2062,47 +2062,33 @@ class Tensor(OpMixin):
|
|||||||
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def parse_formula(formula:str, *operands:Tensor):
|
xs, formula = list(argfix(*operands)), formula.replace(" ", "")
|
||||||
if "..." in (formula := formula.replace(" ", "")):
|
# expand ellipsis to letters, determine output
|
||||||
ell_chars, ell_longest = "".join(c for c in string.ascii_letters if c not in formula), 0
|
if "..." in formula:
|
||||||
for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))):
|
ell, lhs = "".join(c for c in string.ascii_letters if c not in formula), (formula.split("->") + [""])[0]
|
||||||
if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
|
ell_n = [max(0, x.ndim - len(s) + 3) if "..." in s else 0 for s, x in zip(lhs.split(","), xs)]
|
||||||
inputs[i] = inp.replace("...", ell_chars[-ell_count:])
|
for i, (s, x) in enumerate(zip(inputs := lhs.split(","), xs)): inputs[i] = s.replace("...", ell[max(ell_n)-ell_n[i]:max(ell_n)])
|
||||||
inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:]
|
lhs, auto = ",".join(inputs), "".join(sorted(c for c in lhs if lhs.count(c) == 1 and c.isalpha() and c not in ell))
|
||||||
return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \
|
formula = f"{lhs}->{formula.split('->')[1].replace('...', ell[:max(ell_n)]) if '->' in formula else ell[:max(ell_n)] + auto}"
|
||||||
(inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
|
lhs, rhs = formula.split("->") if "->" in formula else (formula, "".join(sorted(c for c in formula if formula.count(c)==1 and c.isalpha())))
|
||||||
return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
inputs = lhs.split(",")
|
||||||
|
if len(xs) != len(inputs): raise ValueError(f"number of operands doesn't match, expected {len(inputs)}, got {len(xs)}")
|
||||||
xs:tuple[Tensor, ...] = argfix(*operands)
|
# trace: take diagonal when letter repeats in single input
|
||||||
inputs_str, output = parse_formula(formula, *xs)
|
for i, (s, x) in enumerate(zip(inputs, xs)):
|
||||||
inputs = inputs_str.split(",")
|
for c in set(s):
|
||||||
if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}")
|
while s.count(c) > 1:
|
||||||
|
j, k, n = s.index(c), s.index(c, s.index(c)+1), cast(int, x.shape[s.index(c)])
|
||||||
# handle trace (repeated letter in single input means take diagonal)
|
perm = [d for d in range(x.ndim) if d not in (j,k)]+[j,k]
|
||||||
xs_:list[Tensor] = list(xs)
|
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1,(n,n+1))[...,0] if x.ndim > 2 else x.diagonal()
|
||||||
for i, (letters, x) in enumerate(zip(inputs, xs)):
|
s = s[:k] + s[k+1:]
|
||||||
for c in set(letters):
|
inputs[i], xs[i] = s, x
|
||||||
while (idxs := [j for j, ch in enumerate(letters) if ch == c]) and len(idxs) > 1:
|
# check sizes and build sorted alphabet
|
||||||
d0, d1, n = idxs[0], idxs[1], cast(int, x.shape[idxs[0]])
|
sz = merge_dicts([dict(zip(s, x.shape)) for s, x in zip(inputs, xs)])
|
||||||
perm = [j for j in range(x.ndim) if j not in (d0, d1)] + [d0, d1]
|
alpha = sorted(sz)
|
||||||
x = x.permute(perm).flatten(-2).pad(((0,0),)*(x.ndim-2)+((0,n),)).unflatten(-1, (n, n+1))[..., 0] if x.ndim > 2 else x.diagonal()
|
# align all tensors to alphabet, multiply, sum non-output, permute to output order
|
||||||
letters = letters[:d1] + letters[d1+1:]
|
xs = [x.permute(*[s.index(c) for c in sorted(s)]).reshape([sz[c] if c in s else 1 for c in alpha]).expand([sz[c] for c in alpha]) if s else x
|
||||||
inputs[i], xs_[i] = letters, x
|
for s, x in zip(inputs, xs)]
|
||||||
|
return functools.reduce(lambda a,b:a*b, xs).sum([i for i,c in enumerate(alpha) if c not in rhs], dtype=dtype).permute(argsort(argsort(list(rhs))))
|
||||||
# map the value of each letter in the formula
|
|
||||||
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs_)]).items())
|
|
||||||
|
|
||||||
lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs]
|
|
||||||
# permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters
|
|
||||||
xs_ = [x.permute(o).reshape([v if l in letters else 1 for l,v in letter_val]).expand([v for _,v in letter_val])
|
|
||||||
for x,(o,letters) in zip(xs_, [list(zip(*l)) if l else ((), ()) for l in lhs])]
|
|
||||||
|
|
||||||
# ordinal encode the output alphabet
|
|
||||||
rhs_order = argsort(argsort(list(output)))
|
|
||||||
|
|
||||||
# sum over all axes that's not in the output, then permute to the output order
|
|
||||||
return functools.reduce(lambda a,b:a*b, xs_) \
|
|
||||||
.sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], dtype=dtype).permute(rhs_order)
|
|
||||||
|
|
||||||
# ***** processing ops *****
|
# ***** processing ops *****
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user