From 3a20cff7c244eb6cb8f9a4567c3c1cfd9cb43b80 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 7 Jun 2024 14:26:02 -0400 Subject: [PATCH] expand ShapeTracker.invert a bit (#4864) removed a type cast and it can early return now [run_process_replay] --- tinygrad/shape/shapetracker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 1569d85d61..dd4de448af 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -25,8 +25,11 @@ class ShapeTracker: return ret def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]: - ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape])) - return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None + inverted_views:List[View] = [] + for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]): + if (inverted:= v.invert(s)) is None: return None + inverted_views.append(inverted) + return ShapeTracker(tuple(inverted_views)).reshape(out_shape) @staticmethod def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))