diff --git a/test/test_ops.py b/test/test_ops.py
index 48e2d3fc65..96bb54617e 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1209,32 +1209,24 @@ class TestOps(unittest.TestCase):
     # match torch ellipsis handling
     helper_test_op([(32, 7, 24, 24, 24), (32, 7, 24, 24, 24)], lambda a, b: torch.einsum('ij...,ij...->ij', [a, b]),
                    lambda a, b: Tensor.einsum('ij...,ij...->ij', [a, b]))
-    # multiple ellipsis in one operand are not allowed. This test shall raise an exception.
-    with self.assertRaises(RuntimeError):
-      helper_test_op([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),
-                     lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b]))
-    # multiple ellipsis must broadcast together. This test shall raise an exception.
-    with self.assertRaises(RuntimeError):
-      helper_test_op([(2, 3, 4, 5), (5, 2, 7)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
-                     lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]))
+    # multiple ellipsis in one operand are not allowed
+    self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('...ik..., ...jk ->', [a, b]),
+                               lambda a, b: Tensor.einsum('...ik..., ...jk ->', [a, b]), expected=(RuntimeError, IndexError))
+    # multiple ellipsis must broadcast together
+    self.helper_test_exception([(2, 3, 4), (2, 3, 4)], lambda a, b: torch.einsum('i...j,ji...->...', [a, b]),
+                               lambda a, b: Tensor.einsum('i...j,ji...->...', [a, b]), expected=RuntimeError)
 
   def test_einsum_shape_check(self):
-    a = Tensor.zeros(3,8,10,5)
-    b = Tensor.zeros(11,5,13,16,8)
-    with self.assertRaises(RuntimeError):
-      Tensor.einsum('pqrs,tuqvr->pstuv',a,b)
+    self.helper_test_exception([(3,8,10,5), (11,5,13,16,8)], lambda a, b: torch.einsum('pqrs,tuqvr->pstuv', [a, b]),
+                               lambda a, b: Tensor.einsum('pqrs,tuqvr->pstuv', [a, b]), expected=RuntimeError)
 
   def test_einsum_arity_check1(self):
-    a = Tensor.zeros(10,15)
-    b = Tensor.zeros(15,20)
-    c = Tensor.zeros(20,10)
-    with self.assertRaises(AssertionError):
-      Tensor.einsum('ij,jk->ij', a,b,c)
+    self.helper_test_exception([(10,15), (15,20), (20,10)], lambda a, b, c: torch.einsum('ij,jk->ij', [a, b, c]),
+                               lambda a, b, c: Tensor.einsum('ij,jk->ij', [a, b, c]), expected=(ValueError, RuntimeError))
 
   def test_einsum_arity_check2(self):
-    a = Tensor.zeros(10,10)
-    with self.assertRaises(AssertionError):
-      Tensor.einsum('ij,jk->ij', a)
+    self.helper_test_exception([(10,10)], lambda a: torch.einsum('ij,jk->ij', a),
+                               lambda a: Tensor.einsum('ij,jk->ij', a), expected=(ValueError, RuntimeError))
 
   @unittest.skipIf(IMAGE>0, "no 1d dot for images")
   def test_dot_1d(self):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index bdfd18ea55..161ae179e1 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2255,7 +2255,7 @@ class Tensor(MathTrait):
     xs:tuple[Tensor, ...] = argfix(*operands)
     inputs_str, output = parse_formula(formula, *xs)
     inputs = inputs_str.split(",")
-    assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
+    if len(xs)!=len(inputs): raise ValueError(f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}")
     # map the value of each letter in the formula
     letter_val = sorted(merge_dicts([dict(zip(letters, tensor.shape)) for letters, tensor in zip(inputs, xs)]).items())
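
Note: the snippet below is a minimal sketch (not part of the patch) of the behavior change in the tensor.py hunk: an operand-count mismatch in Tensor.einsum now surfaces as a ValueError with the message from the f-string above, rather than an AssertionError, assuming parse_formula does not reject the formula earlier.

from tinygrad.tensor import Tensor

try:
  # formula names two operands, but only one tensor is passed
  Tensor.einsum('ij,jk->ij', Tensor.zeros(10, 10))
except ValueError as e:
  # e.g. "number of inputs doesn't match number of operands in formula, expected 2, got 1"
  print(e)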