mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
update many einsum tests (#11981)
correct the exception testing, and raise ValueError instead of assert when checking args
This commit is contained in:
@@ -1209,32 +1209,24 @@ class TestOps(unittest.TestCase):
|
||||
# match torch ellipsis handling
|
||||
helper_test_op([(32, 7, 24, 24, 24), (32, 7, 24, 24, 24)], lambda a, b: torch.einsum('ij...,ij...->ij', [a, b]),
|
||||
lambda a, b: Tensor.einsum('ij...,ij...->ij', [a, b]))
|
||||
# multiple ellipsis in one operand are not allowed. This test shall raise an exception.
|
||||
with self.assertRaises(RuntimeError):
|
||||
helper_test_op([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),
|
||||
lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b]))
|
||||
# multiple ellipsis must broadcast together. This test shall raise an exception.
|
||||
with self.assertRaises(RuntimeError):
|
||||
helper_test_op([(2, 3, 4, 5), (5, 2, 7)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
|
||||
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]))
|
||||
# multiple ellipsis in one operand are not allowed
|
||||
self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),
|
||||
lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b]), expected=(RuntimeError, IndexError))
|
||||
# multiple ellipsis must broadcast together
|
||||
self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
|
||||
lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError)
|
||||
|
||||
def test_einsum_shape_check(self):
|
||||
a = Tensor.zeros(3,8,10,5)
|
||||
b = Tensor.zeros(11,5,13,16,8)
|
||||
with self.assertRaises(RuntimeError):
|
||||
Tensor.einsum('pqrs,tuqvr->pstuv',a,b)
|
||||
self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]),
|
||||
lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError)
|
||||
|
||||
def test_einsum_arity_check1(self):
|
||||
a = Tensor.zeros(10,15)
|
||||
b = Tensor.zeros(15,20)
|
||||
c = Tensor.zeros(20,10)
|
||||
with self.assertRaises(AssertionError):
|
||||
Tensor.einsum('ij,jk->ij', a,b,c)
|
||||
self.helper_test_exception([(10,15), (15,20), (20,10)], lambda a, b, c: torch.einsum('ij,jk->ij', [a, b, c]),
|
||||
lambda a, b, c: Tensor.einsum('ij,jk->ij', [a, b, c]), expected=(ValueError, RuntimeError))
|
||||
|
||||
def test_einsum_arity_check2(self):
|
||||
a = Tensor.zeros(10,10)
|
||||
with self.assertRaises(AssertionError):
|
||||
Tensor.einsum('ij,jk->ij', a)
|
||||
self.helper_test_exception([(10,10)], lambda a: torch.einsum('ij,jk->ij', a),
|
||||
lambda a: Tensor.einsum('ij,jk->ij', a), expected=(ValueError, RuntimeError))
|
||||
|
||||
@unittest.skipIf(IMAGE>0, "no 1d dot for images")
|
||||
def test_dot_1d(self):
|
||||
|
||||
@@ -2255,7 +2255,7 @@ class Tensor(MathTrait):
|
||||
xs:tuple[Tensor, ...] = argfix(*operands)
|
||||
inputs_str, output = parse_formula(formula, *xs)
|
||||
inputs = inputs_str.split(",")
|
||||
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
||||
if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}")
|
||||
|
||||
# map the value of each letter in the formula
|
||||
letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
|
||||
|
||||
Reference in New Issue
Block a user