fix for false merge (#2616)

This commit is contained in:
Amrit Sahu
2023-12-05 21:17:18 +05:30
committed by GitHub
parent fc00da538d
commit a6b68e8e40

View File

@@ -44,7 +44,7 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
while len(new_mask) < len(new_shape):
(l, r), next_stride = (mask[0], mask[1]), new_dim * curr_stride
if old_dim >= new_dim: # need to split mask.
if old_dim >= new_dim * curr_stride: # need to split mask.
offsets.append(off)
if old_dim == next_stride: # simply copy the mask and get next batch for merging
@@ -57,7 +57,7 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
new_mask.append((l % ns // curr_stride, (r - 1) % ns // curr_stride + 1))
curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
elif old_dim < new_dim * curr_stride:
else:
next_mask = next(r_masks, (0, 1))
# combine if the mask can unfold continuously
if (l != 0 or r != old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, tuple(), True