mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
simpler reshape symbolic shape check [pr] (#7837)
This commit is contained in:
@@ -310,8 +310,7 @@ class View:
|
||||
# check for the same size
|
||||
if (self_all_int := all_int(self.shape)):
|
||||
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
|
||||
if resolve(prod(self.shape) != prod(new_shape), False):
|
||||
raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
|
||||
if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
|
||||
|
||||
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
|
||||
|
||||
@@ -322,11 +321,8 @@ class View:
|
||||
if self_all_int and not all_int(new_shape):
|
||||
if len(self.shape) != len(new_shape): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
for si, so in zip(self.shape, new_shape):
|
||||
if isinstance(so, int):
|
||||
if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
else:
|
||||
var_vals = dict([v.unbind() for v in so.vars()])
|
||||
if si != sym_infer(so, var_vals): raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
if not isinstance(so, int): so = sym_infer(so, dict([v.unbind() for v in so.vars()]))
|
||||
if si != so: raise ValueError(f"cannot symbolic reshape non-contiguous {self} -> {new_shape}")
|
||||
# all dimensions matched, return the new view directly
|
||||
return View(new_shape, self.strides, self.offset, self.mask, self.contiguous)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user