Einsum ellipsis support (#6333)

* working ellipsis expansion

* refactor

* fix commas in output

* add capital letters

* refactor
This commit is contained in:
Oleg Rybalko
2024-09-05 05:08:55 +03:00
committed by GitHub
parent dde7a0d79c
commit 64f1384f5b
2 changed files with 17 additions and 7 deletions

View File

@@ -756,10 +756,8 @@ class TestOps(unittest.TestCase):
# bilinear transformation
helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c]))
@unittest.expectedFailure
def test_einsum_ellipsis(self):
"""The expected behavior for einsum is described in the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.einsum.html"""
# TODO: implement ellipsis support in einsum to pass these tests
# test ellipsis
helper_test_op([(3, 8, 9), (3, 8, 9)], lambda a, b: torch.einsum('...id, ...jd -> ...ij', [a, b]),
lambda a, b: Tensor.einsum('...id, ...jd -> ...ij', [a, b]))
@@ -769,6 +767,9 @@ class TestOps(unittest.TestCase):
# multiple ellipsis in different operands with different shapes are allowed
helper_test_op([(2, 3, 4, 5), (5, 2, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]))
# match torch ellipsis handling
helper_test_op([(32, 7, 24, 24, 24), (32, 7, 24, 24, 24)], lambda a, b: torch.einsum('ij...,ij...->ij', [a, b]),
lambda a, b: Tensor.einsum('ij...,ij...->ij', [a, b]))
# multiple ellipsis in one operand are not allowed. This test shall raise an exception.
with self.assertRaises(RuntimeError):
helper_test_op([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),

View File

@@ -1,7 +1,7 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import dataclasses
import time, math, itertools, functools, struct, sys, inspect, pathlib
import time, math, itertools, functools, struct, sys, inspect, pathlib, string
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
from collections import defaultdict
@@ -1745,11 +1745,20 @@ class Tensor:
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
def parse_formula(formula: str, *operands: Tensor):
  """Expand "..." (ellipsis) in an einsum formula into concrete index letters.

  Returns an (inputs_str, output_str) pair in which every ellipsis has been
  replaced by letters not already used in the formula, with all ellipses
  broadcast to the widest rank seen across the operands (torch semantics).
  """
  if "." in formula:
    # Letters still free for ellipsis expansion. Sorted so the expansion is
    # deterministic across runs (raw set iteration order depends on PYTHONHASHSEED).
    ell_chars, ell_longest = "".join(sorted(set(string.ascii_letters) - set(formula))), 0
    inputs = formula.split("->")[0].split(",")
    # BUGFIX: iterate inputs directly instead of enumerate(filter(...)) — the
    # filtered index misaligned operands[i]/inputs[i] whenever an earlier
    # operand had no ellipsis.
    for i, inp in enumerate(inputs):
      if "..." not in inp: continue
      # dims this operand's ellipsis covers: operand rank minus explicit letters.
      # NOTE(review): max(..., 1) makes "..." cover one dim even for a 0-d operand — confirm intended
      if (ell_count := max(operands[i].ndim, 1) - (len(inp) - 3)) > ell_longest: ell_longest = ell_count
      inputs[i] = inp.replace("...", "" if ell_count == 0 else ell_chars[-ell_count:])
    inputs_str, out_ellipse = ",".join(inputs), "" if ell_longest == 0 else ell_chars[-ell_longest:]
    if "->" in formula: return (inputs_str, formula.split("->")[1].replace("...", out_ellipse))
    # implicit output: broadcast (ellipsis) dims first, then non-repeated letters in sorted order
    return (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
  # no ellipsis: split the explicit output, or build the implicit one from non-repeated letters
  return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
xs:Tuple[Tensor] = argfix(*raw_xs)
formula = formula.replace(" ", "")
inputs_str, output = formula.split("->") if "->" in formula else (formula, \
''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
inputs = inputs_str.split(',')
inputs_str, output = parse_formula(formula.replace(" ", ""), *xs)
inputs = inputs_str.split(",")
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
# map the value of each letter in the formula