mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Einsum ellipsis support (#6333)
Working ellipsis expansion; refactor; fix commas in output; add support for capital letters; final refactor.
This commit is contained in:
@@ -756,10 +756,8 @@ class TestOps(unittest.TestCase):
|
||||
# bilinear transformation
|
||||
helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c]))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_einsum_ellipsis(self):
|
||||
"""The expected behavior for einsum is described in the PyTorch docs: https://pytorch.org/docs/stable/generated/torch.einsum.html"""
|
||||
# TODO: implement ellipsis support in einsum to pass these tests
|
||||
# test ellipsis
|
||||
helper_test_op([(3, 8, 9), (3, 8, 9)], lambda a, b: torch.einsum('...id, ...jd -> ...ij', [a, b]),
|
||||
lambda a, b: Tensor.einsum('...id, ...jd -> ...ij', [a, b]))
|
||||
@@ -769,6 +767,9 @@ class TestOps(unittest.TestCase):
|
||||
# multiple ellipses in different operands with different shapes are allowed
|
||||
helper_test_op([(2, 3, 4, 5), (5, 2, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
|
||||
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]))
|
||||
# match torch ellipsis handling
|
||||
helper_test_op([(32, 7, 24, 24, 24), (32, 7, 24, 24, 24)], lambda a, b: torch.einsum('ij...,ij...->ij', [a, b]),
|
||||
lambda a, b: Tensor.einsum('ij...,ij...->ij', [a, b]))
|
||||
# multiple ellipses in one operand are not allowed; this test must raise an exception.
|
||||
with self.assertRaises(RuntimeError):
|
||||
helper_test_op([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
|
||||
from collections import defaultdict
|
||||
@@ -1745,11 +1745,20 @@ class Tensor:
|
||||
print(Tensor.einsum("ij,ij->", x, y).numpy())
|
||||
```
|
||||
"""
|
||||
def parse_formula(formula: str, *operands: Tensor):
  """Expand every '...' (ellipsis) in an einsum formula into concrete subscript letters.

  Returns a tuple ``(inputs_str, output_str)``: the comma-joined input subscripts with
  each ellipsis replaced by letters not used elsewhere in the formula, and the output
  subscripts. When the formula has no explicit ``->``, the implicit output follows the
  numpy/torch convention: broadcast (ellipsis) dims first, then every letter that
  appears exactly once, in sorted order.

  Raises:
    RuntimeError: if any single operand's subscripts contain more than one ellipsis.
  """
  if "." in formula:
    # letters absent from the formula are free to name the broadcast dims
    ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0
    inputs = formula.split("->")[0].split(",")
    # NOTE: iterate ALL inputs so `i` stays aligned with `operands`; enumerating a
    # filter() of the ellipsis inputs would mis-index operands when only some have "..."
    for i, inp in enumerate(inputs):
      if "..." not in inp: continue
      if inp.count("...") > 1: raise RuntimeError("only one ellipsis is allowed per einsum operand")
      # how many dims this operand's ellipsis stands for (a scalar still counts as 1)
      if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count
      inputs[i] = inp.replace("...", "" if ell_count == 0 else ell_chars[-ell_count:])
    # the longest expansion names the broadcast dims that appear in the output
    inputs_str, out_ellipse = ",".join(inputs), "" if ell_longest == 0 else ell_chars[-ell_longest:]
    if "->" in formula: return (inputs_str, formula.split("->")[1].replace("...", out_ellipse))
    # implicit output: broadcast dims first, then the unrepeated letters in sorted order
    return (inputs_str,
            out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse)))
  # no ellipsis: split on '->', or derive the implicit output from the unrepeated letters
  return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
||||
|
||||
xs:Tuple[Tensor] = argfix(*raw_xs)
|
||||
formula = formula.replace(" ", "")
|
||||
inputs_str, output = formula.split("->") if "->" in formula else (formula, \
|
||||
''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha()))
|
||||
inputs = inputs_str.split(',')
|
||||
inputs_str, output = parse_formula(formula.replace(" ", ""), *xs)
|
||||
inputs = inputs_str.split(",")
|
||||
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
||||
|
||||
# map the value of each letter in the formula
|
||||
|
||||
Reference in New Issue
Block a user