diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 1569d85d61..dd4de448af 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -25,8 +25,11 @@ class ShapeTracker: return ret def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]: - ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape])) - return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None + inverted_views:List[View] = [] + for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]): + if (inverted:= v.invert(s)) is None: return None + inverted_views.append(inverted) + return ShapeTracker(tuple(inverted_views)).reshape(out_shape) @staticmethod def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))