mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-17 18:11:49 -05:00
expand ShapeTracker.invert a bit (#4864)
removed a type cast and it can early return now [run_process_replay]
This commit is contained in:
@@ -25,8 +25,11 @@ class ShapeTracker:
|
||||
return ret
|
||||
|
||||
def invert(self, out_shape:Tuple[sint, ...]) -> Optional[ShapeTracker]:
|
||||
ret = tuple(v.invert(s) for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]))
|
||||
return ShapeTracker(cast(Tuple[View, ...], ret)).reshape(out_shape) if all(x is not None for x in ret) else None
|
||||
inverted_views:List[View] = []
|
||||
for v,s in zip(self.views[::-1], [x.shape for x in self.views[::-1][1:]]+[out_shape]):
|
||||
if (inverted:= v.invert(s)) is None: return None
|
||||
inverted_views.append(inverted)
|
||||
return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
|
||||
|
||||
@staticmethod
|
||||
def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
|
||||
|
||||
Reference in New Issue
Block a user