mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 23:38:58 -05:00
fix for false merge (#2616)
This commit is contained in:
@@ -44,7 +44,7 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
|
||||
while len(new_mask) < len(new_shape):
|
||||
(l, r), next_stride = (mask[0], mask[1]), new_dim * curr_stride
|
||||
|
||||
if old_dim >= new_dim: # need to split mask.
|
||||
if old_dim >= new_dim * curr_stride: # need to split mask.
|
||||
offsets.append(off)
|
||||
|
||||
if old_dim == next_stride: # simply copy the mask and get next batch for merging
|
||||
@@ -57,7 +57,7 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl
|
||||
new_mask.append((l % ns // curr_stride, (r - 1) % ns // curr_stride + 1))
|
||||
curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension
|
||||
|
||||
elif old_dim < new_dim * curr_stride:
|
||||
else:
|
||||
next_mask = next(r_masks, (0, 1))
|
||||
# combine if the mask can unfold continuously
|
||||
if (l != 0 or r != old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, tuple(), True
|
||||
|
||||
Reference in New Issue
Block a user