perf: avoid reshaping if not necessary (#1683)

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
Roelof van Dijk
2023-08-27 20:17:04 +02:00
committed by GitHub
parent 328cf2e86a
commit b66f54e379

View File

@@ -228,6 +228,7 @@ class ShapeTracker:
return self
def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
if self.views[-1].shape == new_shape: return self
new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int))
if new_nodes and all(isinstance(s, int) for s in self.shape):
# reshape from all int shape into shape with a variable, update the variable value
@@ -238,7 +239,6 @@ class ShapeTracker:
self.var_vals[new_var] = new_val
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
elif not new_nodes: self.var_vals = {}
if self.views[-1].shape == new_shape: return self
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):