From b66f54e3791f7bb2bd00e3ee2a580a67f9e7a65f Mon Sep 17 00:00:00 2001 From: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com> Date: Sun, 27 Aug 2023 20:17:04 +0200 Subject: [PATCH] perf: avoid reshaping if not necessary (#1683) Co-authored-by: Roelof van Dijk --- tinygrad/shape/shapetracker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index dfd66dc941..75a28601d6 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -228,6 +228,7 @@ class ShapeTracker: return self def reshape(self, new_shape: Tuple[Union[Node,int], ...]): + if self.views[-1].shape == new_shape: return self new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int)) if new_nodes and all(isinstance(s, int) for s in self.shape): # reshape from all int shape into shape with a variable, update the variable value @@ -238,7 +239,6 @@ class ShapeTracker: self.var_vals[new_var] = new_val else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}" elif not new_nodes: self.var_vals = {} - if self.views[-1].shape == new_shape: return self assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}" # only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):