diff --git a/test/external/fuzz_shapetracker.py b/test/external/fuzz_shapetracker.py index 0a10d581ff..ba11fa1ebd 100644 --- a/test/external/fuzz_shapetracker.py +++ b/test/external/fuzz_shapetracker.py @@ -38,17 +38,10 @@ def do_shrink(st): if DEBUG >= 1: print("st.shrink(", shrink, ")") st.shrink(shrink) -def do_stride(st): - c = random.randint(0, len(st.shape)-1) - stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape))) - if DEBUG >= 1: print("st.stride(", stride, ")") - st.stride(stride) - def do_flip(st): - c = random.randint(0, len(st.shape)-1) - stride = tuple(-1 if i==c else 1 for i in range(len(st.shape))) - if DEBUG >= 1: print("st.stride(", stride, ")") - st.stride(stride) + flip = tuple(random.random() < 0.5 for _ in st.shape) + if DEBUG >= 1: print("st.flip(", flip, ")") + st.flip(flip) def do_expand(st): c = [i for i,s in enumerate(st.shape) if s==1] @@ -58,7 +51,7 @@ def do_expand(st): if DEBUG >= 1: print("st.expand(", expand, ")") st.expand(expand) -shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand] +shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_flip, do_expand] if __name__ == "__main__": random.seed(42) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 1a8c077f14..7257a6dd0a 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -43,9 +43,9 @@ class CheckingShapeTracker: self.t = np.broadcast_to(self.t, new_shape) return self - def flip(self, axis): - self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape)))) - self.t = np.flip(self.t, axis) + def flip(self, arg): + self.st = self.st.flip(arg) + self.t = np.flip(self.t, tuple(i for i in range(len(arg)) if arg[i])) return self def shrink(self, arg): @@ -58,11 +58,6 @@ class CheckingShapeTracker: self.t = np.pad(self.t, arg, constant_values=-1) return self - def stride(self, arg): - self.st = self.st.stride(arg) - self.t = self.t[tuple([slice(None, None, x) for x in arg])] - return self - def __getitem__(self, val): return self.t.flatten()[val] @@ -579,20 +574,18 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase): self.st.shrink(((1, 2), (1, 3), (1, 3))) self.st.reshape((1, 4)) self.st.shrink(((0, 1), (1, 3))) - print(self.st.st) self.st = self.st.simplify() - print(self.st.st) def test_case_2(self): - self.st.stride( (1, 1, -2) ) - self.st.reshape( (3, 6) ) + self.st.flip( (True, False, True) ) + self.st.reshape( (3, 9) ) self.st.shrink( ((1, 2), (1, 5)) ) - self.st.stride( (1, -1) ) + self.st.flip( (True, True) ) def test_case_3(self): self.st.shrink( ((0, 2), (0, 2), (0, 1)) ) self.st.permute( (1, 0, 2) ) self.st.reshape( (4,) ) self.st.shrink( ((0, 3),) ) - self.st.stride( (-1,) ) + self.st.flip( (True, False) ) def test_case_4(self): self.st.reshape( (3, 3, 3, 1) ) self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) ) @@ -687,21 +680,13 @@ class TestShapeTracker(unittest.TestCase): self.st.reshape((9,6,1)) self.st.expand((9,6,4)) - def test_pad_stride(self): + def test_pad_flip(self): self.st.pad(((1,4), (1,3))) - self.st.stride((2,2)) + self.st.flip((True, False)) - def test_pad_stride_neg(self): - self.st.pad(((1,2), (1,0))) - self.st.stride((-1,-1)) - - def test_pad_stride_both(self): - self.st.pad(((1,2), (1,0))) - self.st.stride((-2,-2)) - - def test_shrink_pad(self): - self.st.shrink(((0,4), (0,4))) - self.st.pad(((1,1), (1,1))) + def test_pad_flip_int(self): + self.st.pad(((1,4), (1,3))) + self.st.flip((0, 1)) def test_reshape(self): new_shape = self.st.shape[::-1] @@ -722,13 +707,13 @@ class TestShapeTracker(unittest.TestCase): self.apply(lambda x: x.expand(tuple(new_shape))) def test_flip_0(self): - self.apply(lambda x: x.flip((0,))) + self.apply(lambda x: x.flip((True, False))) def test_flip_1(self): - self.apply(lambda x: x.flip((1,))) + self.apply(lambda x: x.flip((False, True))) def test_flip_01(self): - self.apply(lambda x: x.flip((0,1))) + self.apply(lambda x: x.flip((True, True))) def test_slice_0(self): self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1])))) @@ -754,16 +739,13 @@ class TestShapeTracker(unittest.TestCase): self.apply(lambda x: x.shrink(((0, 2), (3, 4)))) self.apply(lambda x: x.expand((2, 10))) - def test_double_stride(self): - self.apply(lambda x: x.stride((1, 2))) - self.apply(lambda x: x.stride((2, 1))) + def test_double_flip(self): + self.apply(lambda x: x.flip((True, False))) + self.apply(lambda x: x.flip((True, False))) - def test_stride(self): self.apply(lambda x: x.stride((2,1))) - def test_stride_int(self): self.apply(lambda x: x.stride((1,2))) - def test_stride_2(self): self.apply(lambda x: x.stride((2,2))) - def test_stride_n(self): self.apply(lambda x: x.stride((-2,1))) - def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2))) - def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2))) + def test_flip(self): self.apply(lambda x: x.flip((True, False))) + def test_flip2(self): self.apply(lambda x: x.flip((False, True))) + def test_flip3(self): self.apply(lambda x: x.flip((True, True))) def test_reshape_then_permute(self): self.test_reshape() @@ -838,24 +820,12 @@ class TestShapeTrackerSize(unittest.TestCase): self.assertEqual(st.real_size(), 10) def test_pad_size_multiview(self): - st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,)) + st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)) self.assertEqual(st.real_size(), 100) - # TODO improve real_size accuracy in cases like this? - @unittest.expectedFailure - def test_stride_size(self): - st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,)) - st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,)) - self.assertEqual(st1.real_size(), 78) - self.assertEqual(st2.real_size(), 65) - - def test_stride_size_bounds(self): - # lower bound checks that real_size doesn't give false positive for fitting in a buffer - # upper bound checks that real_size doesn't exceed N when movementops were applied to from_shape((N,)) - st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,)) - st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,)) - self.assertTrue(78 <= st1.real_size() <= 100) - self.assertTrue(65 <= st2.real_size() <= 100) + def test_flip_size(self): + st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).flip((True, True)) + self.assertEqual(st.real_size(), 100) class TestConsecutive(unittest.TestCase): @classmethod diff --git a/test/unit/test_shapetracker_math.py b/test/unit/test_shapetracker_math.py index ee1f7b76aa..efd017f509 100644 --- a/test/unit/test_shapetracker_math.py +++ b/test/unit/test_shapetracker_math.py @@ -13,7 +13,7 @@ class MultiShapeTracker: def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts] def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts] def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts] - def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts] + def flip(self, arg): self.sts = [x.flip(arg) for x in self.sts] def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts] def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool: @@ -75,8 +75,8 @@ class TestShapeTrackerAdd(unittest.TestCase): backup = st.sts[0] st.sts.append(ShapeTracker.from_shape(backup.shape)) st.reshape( (45,) ) - st.stride( (4,) ) - st.reshape( (4, 3) ) + st.flip( (True,) ) + st.reshape( (15, 3) ) assert st_equal(backup + st.sts[1], st.sts[0]) def test_off_by_one(self): @@ -155,22 +155,17 @@ class TestShapeTrackerInvert(unittest.TestCase): def test_can_invert_flip(self): a = ShapeTracker.from_shape((20, 10)) - x = a.stride((-1,1)) + x = a.flip((True,False)) ap = x + x.invert(a.shape) assert st_equal(ap, a) def test_can_invert_flip_permute(self): a = ShapeTracker.from_shape((20, 10)) x = a.permute((1,0)) - x = x.stride((-1,1)) + x = x.flip((True,False)) ap = x + x.invert(a.shape) assert st_equal(ap, a) - def test_cant_invert_stride(self): - a = ShapeTracker.from_shape((10, 10)) - x = a.stride((2,2)) - assert x.invert(a.shape) is None - def test_invert_failure(self): a = ShapeTracker.from_shape((2, 5)) x = a.pad( ((2, 0), (0, 0)) ) diff --git a/tinygrad/engine/multi.py b/tinygrad/engine/multi.py index c6741b0c98..9234247592 100644 --- a/tinygrad/engine/multi.py +++ b/tinygrad/engine/multi.py @@ -125,9 +125,9 @@ def shrink_multi(root:UOp, multi:UOp): return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src], axis=multi.axis, real=multi.real) -def stride_multi(root:UOp, multi:UOp): - assert multi.axis is None or root.arg[multi.axis] == 1, "flipping not supported on sharded axis" - return UOp.multi(*[x.stride(root.arg) for x in multi.src], axis=multi.axis, real=multi.real) +def flip_multi(root:UOp, multi:UOp): + assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis" + return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis, real=multi.real) def copy_multi(multi:UOp, device:UOp): # if we already have a copy on the device, return that @@ -155,7 +155,7 @@ multi_pm = PatternMatcher([ (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi), (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi), (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi), - (UPat(Ops.STRIDE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), stride_multi), + (UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi), (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi), (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi), (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 86c9f0fa63..91d0d3a599 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -33,7 +33,7 @@ pm_gradient = PatternMatcher([ (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)), (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)), (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)), - (UPat(Ops.STRIDE, name="ret"), lambda ctx, ret: (ctx.stride(ret.arg) if all(x in {-1,1} for x in ret.arg) else None,)), + (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)), # TODO: this cast can be removed by putting the casts around the EXPAND (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)), diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 6b1bfcc822..17575e5f1b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -105,7 +105,7 @@ class Ops(FastEnum): BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702 # movement ops! - RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702 + RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702 # misc ops UNROLL = auto(); CONTRACT = auto() # noqa: E702 @@ -160,7 +160,7 @@ class GroupOp: ALU = set.union(Unary, Binary, Ternary) Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE} - Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.STRIDE} + Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP} Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR} Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART} @@ -513,7 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg) def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg) def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg) - def stride(self, arg:tuple[sint, ...]): return self._mop(Ops.STRIDE, arg) + def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg) # *** uop Buffer stuff *** diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 6897a6cc68..2569c22968 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -130,7 +130,7 @@ class ShapeTracker: def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), )) def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), )) def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), )) - def stride(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), )) + def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), )) def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker: if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,)) @@ -139,4 +139,4 @@ class ShapeTracker: def mop(self, op, arg): return mops[op](self, arg) mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand, - Ops.SHRINK: ShapeTracker.shrink, Ops.STRIDE: ShapeTracker.stride, Ops.PAD: ShapeTracker.pad} + Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad} diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 0ba1044ea5..7db63d3fbc 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -212,7 +212,7 @@ class View: def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]: ret = View.create(self.shape) if self.mask: ret = ret.shrink(self.mask) - ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides))) + ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides))) return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1) @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none @@ -265,15 +265,10 @@ class View: tuple(self.mask[a] for a in axis) if self.mask is not None else None) @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none - def stride(self, mul: tuple[int, ...]) -> View: - # except for the negative case, you can build this from the others. invertible in the negative case - assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}" - strides = tuple([z*m for z,m in zip(self.strides, mul)]) - new_shape = tuple([ceildiv(s, abs(m)) for s,m in zip(self.shape, mul)]) - offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0]) - mask = tuple([(ceildiv(mx if m > 0 else s-my, abs(m)), ceildiv(my if m > 0 else s-mx, abs(m))) \ - for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None - return View.create(new_shape, strides, self.offset + offset, mask) + def flip(self, arg: tuple[bool, ...]) -> View: + offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f) + mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None + return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask) @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 3f1eed28de..3d06e15bb0 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1005,7 +1005,7 @@ class Tensor(SimpleMathTrait): """ axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args)) if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}") - return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))])) + return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))])) def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor: """