mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 00:08:16 -05:00
minor change to _reshape_mask [pr] (#8324)
formatting before logic change
This commit is contained in:
@@ -43,7 +43,7 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple
|
||||
-> Optional[Tuple[Tuple[sint, sint], ...]]:
|
||||
"""Returns the new mask if reshape is possible, and None if not possible."""
|
||||
if _mask is None: return tuple((0, s) for s in new_shape)
|
||||
if any(not all_int(m) for m in _mask): return None
|
||||
if not all_int(flatten(_mask)): return None
|
||||
|
||||
new_mask: List[Tuple[int, int]] = []
|
||||
# _mask is all int here
|
||||
@@ -58,17 +58,17 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple
|
||||
new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1))
|
||||
curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
|
||||
elif old_dim > next_stride: # mask can only be splitted if reshape doesn't cut across the mask.
|
||||
if (((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride)
|
||||
or old_dim % next_stride != 0): return None
|
||||
if old_dim % next_stride != 0: return None
|
||||
if (l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride: return None
|
||||
new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1))
|
||||
curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
|
||||
|
||||
else:
|
||||
next_mask = next(r_masks, (0, 1))
|
||||
# combine if the mask can unfold continuously
|
||||
if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
|
||||
mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
|
||||
|
||||
# TODO: do we need this?
|
||||
for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
|
||||
if mask != (0, 1): return ((0, 0),) * len(new_shape) # invalid mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user