diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index dfd66dc941..75a28601d6 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -228,6 +228,7 @@ class ShapeTracker: return self def reshape(self, new_shape: Tuple[Union[Node,int], ...]): + if self.views[-1].shape == new_shape: return self new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int)) if new_nodes and all(isinstance(s, int) for s in self.shape): # reshape from all int shape into shape with a variable, update the variable value @@ -238,7 +239,6 @@ class ShapeTracker: self.var_vals[new_var] = new_val else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}" elif not new_nodes: self.var_vals = {} - if self.views[-1].shape == new_shape: return self assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}" # only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):