clean up reshape size check [pr] (#8067)

removed a resolve, and remove special case for 0 size assert since it's covered by generic size check
This commit is contained in:
chenyu
2024-12-06 07:51:19 -05:00
committed by GitHub
parent 074a67a6eb
commit a77ee72d11
2 changed files with 3 additions and 6 deletions

View File

@@ -573,7 +573,7 @@ class TestZeroShapeTensor(unittest.TestCase):
a = t.reshape(0)
assert a.shape == (0,)
np.testing.assert_equal(a.numpy(), np.zeros((0,)))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
# cannot reshape from size 0 to size 1
a = t.reshape(())

View File

@@ -298,16 +298,13 @@ class View:
def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
if self.shape == new_shape: return self
# TODO: this resolve shouldn't be needed
assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
if 0 in self.shape:
assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
return View.create(new_shape)
assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
# check for the same size
if (self_all_int := all_int(self.shape)):
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
if 0 in self.shape: return View.create(new_shape)
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
# after the asserts, it's okay to check contiguous