mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
clean up reshape size check [pr] (#8067)
removed a resolve, and remove special case for 0 size assert since it's covered by generic size check
This commit is contained in:
@@ -573,7 +573,7 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
a = t.reshape(0)
|
||||
assert a.shape == (0,)
|
||||
np.testing.assert_equal(a.numpy(), np.zeros((0,)))
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
# cannot reshape from size 0 to size 1
|
||||
a = t.reshape(())
|
||||
|
||||
|
||||
@@ -298,16 +298,13 @@ class View:
|
||||
def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
|
||||
if self.shape == new_shape: return self
|
||||
|
||||
# TODO: this resolve shouldn't be needed
|
||||
assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
|
||||
if 0 in self.shape:
|
||||
assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
|
||||
return View.create(new_shape)
|
||||
assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
|
||||
# check for the same size
|
||||
if (self_all_int := all_int(self.shape)):
|
||||
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
|
||||
if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
|
||||
|
||||
if 0 in self.shape: return View.create(new_shape)
|
||||
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
|
||||
|
||||
# after the asserts, it's okay to check contiguous
|
||||
|
||||
Reference in New Issue
Block a user