diff --git a/test/test_ops.py b/test/test_ops.py index daf64b80d6..7043f57f8c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1123,6 +1123,8 @@ class TestOps(unittest.TestCase): helper_test_op([(1,)], lambda x: x.reshape([])) helper_test_op([()], lambda x: x.reshape([1])) helper_test_op([()], lambda x: x.reshape([1, 1, 1])) + self.helper_test_exception([(3, 4)], lambda x: x.reshape((-1, -1, 2)), lambda x: x.reshape((-1, -1, 2)), expected=RuntimeError) + self.helper_test_exception([(3, 4)], lambda x: x.reshape((-1, -1, -1, 2)), lambda x: x.reshape((-1, -1, -1, 2)), expected=RuntimeError) with self.assertRaises(ValueError): x = Tensor.ones((4,3,6,6)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 63891b7f1e..15cdda0da0 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -765,6 +765,7 @@ class Tensor: ``` """ new_shape = argfix(shape, *args) + if new_shape.count(-1) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)]) return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self