From a77ee72d1103ed76a29bad8a9b890768d95149e2 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 6 Dec 2024 07:51:19 -0500 Subject: [PATCH] clean up reshape size check [pr] (#8067) removed a resolve, and remove special case for 0 size assert since it's covered by generic size check --- test/test_tensor.py | 2 +- tinygrad/shape/view.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index 34eb7fa697..59f188eed5 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -573,7 +573,7 @@ class TestZeroShapeTensor(unittest.TestCase): a = t.reshape(0) assert a.shape == (0,) np.testing.assert_equal(a.numpy(), np.zeros((0,))) - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): # cannot reshape from size 0 to size 1 a = t.reshape(()) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index d14eac4d0d..1a93c867fa 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -298,16 +298,13 @@ class View: def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]: if self.shape == new_shape: return self - # TODO: this resolve shouldn't be needed - assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}" - if 0 in self.shape: - assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}" - return View.create(new_shape) + assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}" # check for the same size if (self_all_int := all_int(self.shape)): assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") + if 0 in self.shape: return View.create(new_shape) if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None # after the asserts, it's okay to check contiguous