From dae89768896ee079288c0f25946fbfc678b73e18 Mon Sep 17 00:00:00 2001 From: Peter Cawley Date: Wed, 20 Dec 2023 22:00:58 +0000 Subject: [PATCH] Fix reshape merging with masks (#2877) --- test/unit/test_shapetracker.py | 10 +-------- tinygrad/shape/view.py | 41 +++++++++++++++------------------- 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 5908fd7b6e..02103e93df 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -256,7 +256,6 @@ class TestIndexExpressions2d(unittest.TestCase): assert len(self.st.views) == 1 self.st.assert_same() - @unittest.expectedFailure def test_reshape_combining_1(self): self.st = CheckingShapeTracker((2,1,10)) self.st.pad(((2,6), (0,0), (0,0))) @@ -264,7 +263,6 @@ class TestIndexExpressions2d(unittest.TestCase): assert len(self.st.views) == 1 self.st.assert_same() - @unittest.expectedFailure def test_reshape_combining_2(self): self.st = CheckingShapeTracker((1,1,5)) self.st.pad(((3,6), (0,0), (0,5))) @@ -272,7 +270,6 @@ class TestIndexExpressions2d(unittest.TestCase): assert len(self.st.views) == 1 self.st.assert_same() - @unittest.expectedFailure def test_reshape_combining_3(self): self.st = CheckingShapeTracker((1,1,4)) self.st.pad(((3,6), (0,0), (1,5))) @@ -281,7 +278,6 @@ class TestIndexExpressions2d(unittest.TestCase): assert self.st.views[0].mask[0] == (31, 35) self.st.assert_same() - @unittest.expectedFailure def test_reshape_combining_4(self): # interestingly this one is quite slow self.st = CheckingShapeTracker((1,1,5,5,1,1,5)) @@ -290,7 +286,6 @@ class TestIndexExpressions2d(unittest.TestCase): assert len(self.st.views) == 1 self.st.assert_same() - @unittest.expectedFailure def test_reshape_splitting_combining(self): self.st = CheckingShapeTracker((1,5,5)) self.st.pad(((0,4), (0,5), (0,0))) @@ -298,7 +293,6 @@ class TestIndexExpressions2d(unittest.TestCase): assert len(self.st.views) == 1 self.st.assert_same() - @unittest.expectedFailure def test_reshape_only_1s(self): self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1)) self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0))) @@ -342,7 +336,6 @@ class TestIndexExpressions2d(unittest.TestCase): assert len(self.st.views) == 1 self.st.assert_same() - @unittest.expectedFailure def test_splitting_big(self): self.st = CheckingShapeTracker((1, 5, 1, 15, 1)) self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0))) @@ -351,10 +344,9 @@ class TestIndexExpressions2d(unittest.TestCase): self.st.reshape((2,3,5,2,5)) assert len(self.st.views) == 1 v = self.st.views[-1] - assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5)) + assert v.strides == (0, 5, 1, 0, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5)) self.st.assert_same() - @unittest.expectedFailure def test_combining_big(self): self.st = CheckingShapeTracker((1,3,1,5,3,1)) self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0))) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index f7ec6f0863..8f6b228501 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -33,45 +33,39 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu return tuple(ret) @functools.lru_cache(maxsize=None) -def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], Optional[Tuple[sint, ...]], bool]: - if view.mask is None: return view.mask, None, False - if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in view.mask): return view.mask, None, True +def _reshape_mask(view: View, new_shape:Tuple[sint, ...]) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], bool]: + if view.mask is None: return view.mask, False + if any(not isinstance(m[0], int) or not isinstance(m[1], int) for m in view.mask): return view.mask, True new_mask: List[Tuple[int, int]] = [] r_masks, r_shape, r_new_shape = reversed(view.mask), reversed(view.shape), reversed(new_shape) - curr_stride, off, offsets, old_dim, new_dim, mask = 1, 0, [], next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) - # off represents offset while combining masks of range one & zero stride - if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), None, False # invalid mask + curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) + if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask while len(new_mask) < len(new_shape): (l, r), next_stride = mask, new_dim * curr_stride if old_dim >= next_stride: # need to split mask. - offsets.append(off) - if old_dim == next_stride: # simply copy the mask and get next batch for merging new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1)) - curr_stride, off, old_dim, new_dim, mask = 1, 0, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) - if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), None, False # invalid mask + curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) + if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), False # invalid mask else: # mask can only be splitted if reshape doesn't cut across the mask. - if ((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride): return view.mask, None, True + if ((l % next_stride != 0 or r % next_stride != 0) and l // next_stride != (r - 1) // next_stride): return view.mask, True new_mask.append((l % next_stride // curr_stride, (r - 1) % next_stride // curr_stride + 1)) curr_stride, new_dim = next_stride, next(r_new_shape, 1) # need to get mask for next dimension else: - # TODO: fix this, it's incorrect - return view.mask, None, True - # next_mask = next(r_masks, (0, 1)) - # # combine if the mask can unfold continuously - # if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, None, True - # if next_mask != (0, 1) and mask != (0, 1) and (next_mask[1] - next_mask[0] == 1): off += next_mask[0] * old_dim - # mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1) + next_mask = next(r_masks, (0, 1)) + # combine if the mask can unfold continuously + if mask != (0, old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, True + mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1) for mask in r_masks: # if the old shape has leading 1s, need to make sure their mask is (0,1) - if mask != (0, 1): return ((0, 0),) * len(new_shape), None, False + if mask != (0, 1): return ((0, 0),) * len(new_shape), False # invalid mask - return tuple(reversed(new_mask)), tuple(offsets), False + return tuple(reversed(new_mask)), False @dataclass(frozen=True) class View: @@ -190,8 +184,9 @@ class View: if acc != merged_dim: break else: strides += [0,] * (len(new_shape) - len(strides)) - mask, off_mask, extra = _reshape_mask(self, new_shape) - total_offset = sum([off * s for off, s in zip(off_mask, strides)]) if off_mask else 0 - if not extra: return View.create(new_shape, tuple(reversed(strides)), self.offset - total_offset, mask) + mask, extra = _reshape_mask(self, new_shape) + fstrides = filter_strides(tuple(e-b for b,e in mask) if mask else new_shape, tuple(reversed(strides))) + extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - (sum(m[0] * s for m,s in zip(mask, fstrides)) if mask else 0) # noqa: E501 + if not extra: return View.create(new_shape, fstrides, self.offset + extra_offset, mask) return None