mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
perf: avoid reshaping if not necessary (#1683)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
@@ -228,6 +228,7 @@ class ShapeTracker:
|
||||
return self
|
||||
|
||||
def reshape(self, new_shape: Tuple[Union[Node,int], ...]):
|
||||
if self.views[-1].shape == new_shape: return self
|
||||
new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int))
|
||||
if new_nodes and all(isinstance(s, int) for s in self.shape):
|
||||
# reshape from all int shape into shape with a variable, update the variable value
|
||||
@@ -238,7 +239,6 @@ class ShapeTracker:
|
||||
self.var_vals[new_var] = new_val
|
||||
else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}"
|
||||
elif not new_nodes: self.var_vals = {}
|
||||
if self.views[-1].shape == new_shape: return self
|
||||
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
|
||||
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
|
||||
if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):
|
||||
|
||||
Reference in New Issue
Block a user