simpler reshape symbolic shape check [pr] (#7837)

This commit is contained in:
chenyu
2024-11-21 22:53:57 -05:00
committed by GitHub
parent 1d6d842887
commit 6229d87f45

View File

@@ -310,8 +310,7 @@ class View:
# check for the same size
if (self_all_int := all_int(self.shape)):
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
if resolve(prod(self.shape) != prod(new_shape), False):
raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
@@ -322,11 +321,8 @@ class View:
if self_all_int and not all_int(new_shape):
if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
for si, so in zip(self.shape, new_shape):
if isinstance(so, int):
if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
else:
var_vals = dict([v.unbind() for v in so.vars()])
if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()]))
if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
# all dimensions matched, return the new view directly
return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)