expand ShapeTracker.invert a bit (#4864)

removed a type cast and it can early return now

[run_process_replay]
This commit is contained in:
chenyu
2024-06-07 14:26:02 -04:00
committed by GitHub
parent 688b14c933
commit 3a20cff7c2

View File

@@ -25,8 +25,11 @@ class ShapeTracker:
return ret
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None
inverted_views:List[View] = []
for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
if (inverted:= v.invert(s)) is None: return None
inverted_views.append(inverted)
return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
@staticmethod
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))