From a6b68e8e403ff15d634c447b8eb38d0ebde76bf4 Mon Sep 17 00:00:00 2001 From: Amrit Sahu <88420255+sahamrit@users.noreply.github.com> Date: Tue, 5 Dec 2023 21:17:18 +0530 Subject: [PATCH] fix for false merge (#2616) --- tinygrad/shape/view.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index eb985854c9..39b4ee084a 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -44,7 +44,7 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl while len(new_mask) < len(new_shape): (l, r), next_stride = (mask[0], mask[1]), new_dim * curr_stride - if old_dim >= new_dim: # need to split mask. + if old_dim >= new_dim * curr_stride: # need to split mask. offsets.append(off) if old_dim == next_stride: # simply copy the mask and get next batch for merging @@ -57,7 +57,7 @@ def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tupl new_mask.append((l % ns // curr_stride, (r - 1) % ns // curr_stride + 1)) curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension - elif old_dim < new_dim * curr_stride: + else: next_mask = next(r_masks, (0, 1)) # combine if the mask can unfold continuously if (l != 0 or r != old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, tuple(), True