From 6229d87f458f1bcdbde0bd60ba00b8d8b96e3fa1 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 21 Nov 2024 22:53:57 -0500 Subject: [PATCH] simpler reshape symbolic shape check [pr] (#7837) --- tinygrad/shape/view.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index a3d34e3b4b..80fb61da86 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -310,8 +310,7 @@ class View: # check for the same size if (self_all_int := all_int(self.shape)): assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim" - if resolve(prod(self.shape) != prod(new_shape), False): - raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") + if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}") if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None @@ -322,11 +321,8 @@ class View: if self_all_int and not all_int(new_shape): if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") for si, so in zip(self.shape, new_shape): - if isinstance(so, int): - if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") - else: - var_vals = dict([v.unbind() for v in so.vars()]) - if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") + if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()])) + if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}") # all dimensions matched, return the new view directly return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)