[bounty] Stride is flip (#8792)

* replace stride with flip

* Complete replacing stride with flip

clean flip function in view.py
fix tests

* fix tests for multi shapetracker

* fix tests for fuzz shapetracker

* fix tests for fuzz shapetracker

* debug

* debug

* fix

* fix

* fix

---------

Co-authored-by: George Hotz <geohot@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Ankit Avinash
2025-01-31 08:04:10 +05:30
committed by GitHub
parent 0513b0c17d
commit 7647cd8428
9 changed files with 50 additions and 97 deletions

View File

@@ -38,17 +38,10 @@ def do_shrink(st):
if DEBUG >= 1: print("st.shrink(", shrink, ")")
st.shrink(shrink)
def do_stride(st):
c = random.randint(0, len(st.shape)-1)
stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape)))
if DEBUG >= 1: print("st.stride(", stride, ")")
st.stride(stride)
def do_flip(st):
c = random.randint(0, len(st.shape)-1)
stride = tuple(-1 if i==c else 1 for i in range(len(st.shape)))
if DEBUG >= 1: print("st.stride(", stride, ")")
st.stride(stride)
flip = tuple(random.random() < 0.5 for _ in st.shape)
if DEBUG >= 1: print("st.flip(", flip, ")")
st.flip(flip)
def do_expand(st):
c = [i for i,s in enumerate(st.shape) if s==1]
@@ -58,7 +51,7 @@ def do_expand(st):
if DEBUG >= 1: print("st.expand(", expand, ")")
st.expand(expand)
shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand]
shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_flip, do_expand]
if __name__ == "__main__":
random.seed(42)

View File

@@ -43,9 +43,9 @@ class CheckingShapeTracker:
self.t = np.broadcast_to(self.t, new_shape)
return self
def flip(self, axis):
self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
self.t = np.flip(self.t, axis)
def flip(self, arg):
  """Apply a flip to both self.st and the NumPy array self.t, then return self.

  `arg` carries one truthy/falsy flag per axis. It is handed to self.st.flip unchanged,
  while np.flip is given the positions of the truthy flags.
  """
  self.st = self.st.flip(arg)
  # np.flip takes axis indices, not per-axis flags
  flipped_axes = tuple(axis for axis, flagged in enumerate(arg) if flagged)
  self.t = np.flip(self.t, flipped_axes)
  return self
def shrink(self, arg):
@@ -58,11 +58,6 @@ class CheckingShapeTracker:
self.t = np.pad(self.t, arg, constant_values=-1)
return self
def stride(self, arg):
self.st = self.st.stride(arg)
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
return self
def __getitem__(self, val):
return self.t.flatten()[val]
@@ -579,20 +574,18 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
self.st.shrink(((1, 2), (1, 3), (1, 3)))
self.st.reshape((1, 4))
self.st.shrink(((0, 1), (1, 3)))
print(self.st.st)
self.st = self.st.simplify()
print(self.st.st)
def test_case_2(self):
self.st.stride( (1, 1, -2) )
self.st.reshape( (3, 6) )
self.st.flip( (True, False, True) )
self.st.reshape( (3, 9) )
self.st.shrink( ((1, 2), (1, 5)) )
self.st.stride( (1, -1) )
self.st.flip( (True, True) )
def test_case_3(self):
self.st.shrink( ((0, 2), (0, 2), (0, 1)) )
self.st.permute( (1, 0, 2) )
self.st.reshape( (4,) )
self.st.shrink( ((0, 3),) )
self.st.stride( (-1,) )
self.st.flip( (True, False) )
def test_case_4(self):
self.st.reshape( (3, 3, 3, 1) )
self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) )
@@ -687,21 +680,13 @@ class TestShapeTracker(unittest.TestCase):
self.st.reshape((9,6,1))
self.st.expand((9,6,4))
def test_pad_stride(self):
def test_pad_flip(self):
self.st.pad(((1,4), (1,3)))
self.st.stride((2,2))
self.st.flip((True, False))
def test_pad_stride_neg(self):
self.st.pad(((1,2), (1,0)))
self.st.stride((-1,-1))
def test_pad_stride_both(self):
self.st.pad(((1,2), (1,0)))
self.st.stride((-2,-2))
def test_shrink_pad(self):
self.st.shrink(((0,4), (0,4)))
self.st.pad(((1,1), (1,1)))
def test_pad_flip_int(self):
self.st.pad(((1,4), (1,3)))
self.st.flip((0, 1))
def test_reshape(self):
new_shape = self.st.shape[::-1]
@@ -722,13 +707,13 @@ class TestShapeTracker(unittest.TestCase):
self.apply(lambda x: x.expand(tuple(new_shape)))
def test_flip_0(self):
self.apply(lambda x: x.flip((0,)))
self.apply(lambda x: x.flip((True, False)))
def test_flip_1(self):
self.apply(lambda x: x.flip((1,)))
self.apply(lambda x: x.flip((False, True)))
def test_flip_01(self):
self.apply(lambda x: x.flip((0,1)))
self.apply(lambda x: x.flip((True, True)))
def test_slice_0(self):
self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1]))))
@@ -754,16 +739,13 @@ class TestShapeTracker(unittest.TestCase):
self.apply(lambda x: x.shrink(((0, 2), (3, 4))))
self.apply(lambda x: x.expand((2, 10)))
def test_double_stride(self):
self.apply(lambda x: x.stride((1, 2)))
self.apply(lambda x: x.stride((2, 1)))
def test_double_flip(self):
self.apply(lambda x: x.flip((True, False)))
self.apply(lambda x: x.flip((True, False)))
def test_stride(self): self.apply(lambda x: x.stride((2,1)))
def test_stride_int(self): self.apply(lambda x: x.stride((1,2)))
def test_stride_2(self): self.apply(lambda x: x.stride((2,2)))
def test_stride_n(self): self.apply(lambda x: x.stride((-2,1)))
def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2)))
def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2)))
def test_flip(self): self.apply(lambda x: x.flip((True, False)))
def test_flip2(self): self.apply(lambda x: x.flip((False, True)))
def test_flip3(self): self.apply(lambda x: x.flip((True, True)))
def test_reshape_then_permute(self):
self.test_reshape()
@@ -838,24 +820,12 @@ class TestShapeTrackerSize(unittest.TestCase):
self.assertEqual(st.real_size(), 10)
def test_pad_size_multiview(self):
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,))
self.assertEqual(st.real_size(), 100)
# TODO improve real_size accuracy in cases like this?
@unittest.expectedFailure
def test_stride_size(self):
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
self.assertEqual(st1.real_size(), 78)
self.assertEqual(st2.real_size(), 65)
def test_stride_size_bounds(self):
# lower bound checks that real_size doesn't give false positive for fitting in a buffer
# upper bound checks that real_size doesn't exceed N when movementops were applied to from_shape((N,))
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
self.assertTrue(78 <= st1.real_size() <= 100)
self.assertTrue(65 <= st2.real_size() <= 100)
def test_flip_size(self):
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).flip((True, True))
self.assertEqual(st.real_size(), 100)
class TestConsecutive(unittest.TestCase):
@classmethod

View File

@@ -13,7 +13,7 @@ class MultiShapeTracker:
def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts]
def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts]
def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts]
def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts]
def flip(self, arg): self.sts = [x.flip(arg) for x in self.sts]
def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts]
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
@@ -75,8 +75,8 @@ class TestShapeTrackerAdd(unittest.TestCase):
backup = st.sts[0]
st.sts.append(ShapeTracker.from_shape(backup.shape))
st.reshape( (45,) )
st.stride( (4,) )
st.reshape( (4, 3) )
st.flip( (True,) )
st.reshape( (15, 3) )
assert st_equal(backup + st.sts[1], st.sts[0])
def test_off_by_one(self):
@@ -155,22 +155,17 @@ class TestShapeTrackerInvert(unittest.TestCase):
def test_can_invert_flip(self):
a = ShapeTracker.from_shape((20, 10))
x = a.stride((-1,1))
x = a.flip((True,False))
ap = x + x.invert(a.shape)
assert st_equal(ap, a)
def test_can_invert_flip_permute(self):
a = ShapeTracker.from_shape((20, 10))
x = a.permute((1,0))
x = x.stride((-1,1))
x = x.flip((True,False))
ap = x + x.invert(a.shape)
assert st_equal(ap, a)
def test_cant_invert_stride(self):
a = ShapeTracker.from_shape((10, 10))
x = a.stride((2,2))
assert x.invert(a.shape) is None
def test_invert_failure(self):
a = ShapeTracker.from_shape((2, 5))
x = a.pad( ((2, 0), (0, 0)) )

View File

@@ -125,9 +125,9 @@ def shrink_multi(root:UOp, multi:UOp):
return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
axis=multi.axis, real=multi.real)
def stride_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.arg[multi.axis] == 1, "flipping not supported on sharded axis"
return UOp.multi(*[x.stride(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
def flip_multi(root:UOp, multi:UOp):
  """Apply the flip described by `root.arg` to every source of `multi` and rebuild the multi.

  The `axis` and `real` of `multi` are passed through to the new multi unchanged.
  """
  # the flag at position multi.axis must be falsy: flipping along that axis is rejected
  assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
  flipped_srcs = [src.flip(root.arg) for src in multi.src]
  return UOp.multi(*flipped_srcs, axis=multi.axis, real=multi.real)
def copy_multi(multi:UOp, device:UOp):
# if we already have a copy on the device, return that
@@ -155,7 +155,7 @@ multi_pm = PatternMatcher([
(UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
(UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
(UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
(UPat(Ops.STRIDE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), stride_multi),
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),

View File

@@ -33,7 +33,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
(UPat(Ops.STRIDE, name="ret"), lambda ctx, ret: (ctx.stride(ret.arg) if all(x in {-1,1} for x in ret.arg) else None,)),
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
# TODO: this cast can be removed by putting the casts around the EXPAND
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),

View File

@@ -105,7 +105,7 @@ class Ops(FastEnum):
BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
# movement ops!
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
# misc ops
UNROLL = auto(); CONTRACT = auto() # noqa: E702
@@ -160,7 +160,7 @@ class GroupOp:
ALU = set.union(Unary, Binary, Ternary)
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.STRIDE}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
@@ -513,7 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
def stride(self, arg:tuple[sint, ...]): return self._mop(Ops.STRIDE, arg)
def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
# *** uop Buffer stuff ***

View File

@@ -130,7 +130,7 @@ class ShapeTracker:
def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
def stride(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
# applies flip to the last view only; all earlier views are carried over as-is.
# `mul` holds per-axis flags and is forwarded unchanged to the last view's flip.
def flip(self, mul: tuple[bool, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))
def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
@@ -139,4 +139,4 @@ class ShapeTracker:
def mop(self, op, arg): return mops[op](self, arg)
mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
Ops.SHRINK: ShapeTracker.shrink, Ops.STRIDE: ShapeTracker.stride, Ops.PAD: ShapeTracker.pad}
Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}

View File

@@ -212,7 +212,7 @@ class View:
def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
ret = View.create(self.shape)
if self.mask: ret = ret.shrink(self.mask)
ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
@@ -265,15 +265,10 @@ class View:
tuple(self.mask[a] for a in axis) if self.mask is not None else None)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def stride(self, mul: tuple[int, ...]) -> View:
# except for the negative case, you can build this from the others. invertible in the negative case
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
strides = tuple([z*m for z,m in zip(self.strides, mul)])
new_shape = tuple([ceildiv(s, abs(m)) for s,m in zip(self.shape, mul)])
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
mask = tuple([(ceildiv(mx if m > 0 else s-my, abs(m)), ceildiv(my if m > 0 else s-mx, abs(m))) \
for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
return View.create(new_shape, strides, self.offset + offset, mask)
def flip(self, arg: tuple[bool, ...]) -> View:
  """Build, via View.create, a view with this view's shape whose flagged axes run in reverse.

  `arg` carries one flag per axis; axes with a truthy flag get a negated stride.
  """
  # a reversed axis starts from its last element, so the offset advances by (size-1)*stride for it
  flip_offset = sum((dim-1)*st for dim,st,flagged in zip(self.shape, self.strides, arg) if flagged)
  if self.mask is None: new_mask = None
  # on a reversed axis the mask bounds are mirrored: (lo, hi) becomes (size-hi, size-lo)
  else: new_mask = tuple((dim-hi, dim-lo) if flagged else (lo, hi) for (lo,hi),dim,flagged in zip(self.mask, self.shape, arg))
  new_strides = tuple(-st if flagged else st for st,flagged in zip(self.strides, arg))
  return View.create(self.shape, new_strides, self.offset+flip_offset, new_mask)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:

View File

@@ -1005,7 +1005,7 @@ class Tensor(SimpleMathTrait):
"""
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))]))
return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
"""