diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index eb985854c9..39b4ee084a 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -44,7 +44,7 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl while len(new_mask) < len(new_shape): (l, r), next_stride = (mask[0], mask[1]), new_dim * curr_stride - if old_dim >= new_dim: # need to split mask. + if old_dim >= new_dim * curr_stride: # need to split mask. offsets.append(off) if old_dim == next_stride: # simply copy the mask and get next batch for merging @@ -57,7 +57,7 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl new_mask.append((l % ns // curr_stride, (r - 1) % ns // curr_stride + 1)) curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension - elif old_dim < new_dim * curr_stride: + else: next_mask = next(r_masks, (0, 1)) # combine if the mask can unfold continuously if (l != 0 or r != old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, tuple(), True