mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
minimize view _reshape_mask API [run_process_replay] (#5359)
* minimize view _reshape_mask API [run_process_replay]
  _reshape_mask is determined only by mask, old_shape, and new_shape; it does not need to take the whole view as input.
* combine
This commit is contained in:
@@ -35,15 +35,17 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
|
||||
return tuple(ret)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Optional[Tuple[Tuple[sint, sint], ...]]:
|
||||
def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) \
|
||||
-> Optional[Tuple[Tuple[sint, sint], ...]]:
|
||||
"""Returns the new mask if reshape is possible, and None if not possible."""
|
||||
if view.mask is None: return tuple((0, s) for s in new_shape)
|
||||
if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in view.mask): return None
|
||||
new_mask: List[Tuple[int, int]] = []
|
||||
if _mask is None: return tuple((0, s) for s in new_shape)
|
||||
if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in _mask): return None
|
||||
if any(m[1] - m[0] < 1 for m in _mask): return ((0, 0),) * len(new_shape) # zero mask
|
||||
|
||||
r_masks, r_shape, r_new_shape = reversed(view.mask), reversed(view.shape), reversed(new_shape)
|
||||
new_mask: List[Tuple[int, int]] = []
|
||||
# _mask is all int here
|
||||
r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int], ...], _mask)), reversed(old_shape), reversed(new_shape)
|
||||
curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
|
||||
if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape) # invalid mask
|
||||
|
||||
while len(new_mask) < len(new_shape):
|
||||
(l, r), next_stride = mask, new_dim * curr_stride
|
||||
@@ -52,7 +54,6 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Optional[Tuple[Tupl
|
||||
if old_dim == next_stride: # simply copy the mask and get next batch for merging
|
||||
new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
|
||||
curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
|
||||
if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape) # invalid mask
|
||||
|
||||
else: # mask can only be splitted if reshape doesn't cut across the mask.
|
||||
if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
|
||||
@@ -302,7 +303,7 @@ class View:
|
||||
if acc != merged_dim: break
|
||||
else:
|
||||
strides += [0,] * (len(new_shape) - len(strides))
|
||||
new_mask = _reshape_mask(self, new_shape)
|
||||
new_mask = _reshape_mask(self.mask, self.shape, new_shape)
|
||||
if new_mask is not None:
|
||||
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
|
||||
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
|
||||
|
||||
Reference in New Issue
Block a user