remove a leading 1 check in _reshape_mask [pr] (#8327)

the only possible mask for it is either (0, 0) or (0, 1). so the logic is no-op
This commit is contained in:
chenyu
2024-12-18 19:30:10 -05:00
committed by GitHub
parent 8a8eaa1ed9
commit accc186c8b

View File

@@ -68,10 +68,6 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple
if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return None
mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
# TODO: do we need this?
for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1)
if mask != (0, 1): return ((0, 0),) * len(new_shape) # invalid mask
return tuple(reversed(new_mask))
def unravel(shape:Tuple[sint, ...], offset:sint) -> List[sint]: