mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
[bounty] Stride is flip (#8792)
* replace stride with flip * Complete replacing stride with flip clean flip function in view.py fix tests * fix tests for multi shapetracker * fix tests for fuzz shapetracker * fix tests for fuzz shapetracker * debug * debug * fix * fix * fix --------- Co-authored-by: George Hotz <geohot@gmail.com> Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
15
test/external/fuzz_shapetracker.py
vendored
15
test/external/fuzz_shapetracker.py
vendored
@@ -38,17 +38,10 @@ def do_shrink(st):
|
||||
if DEBUG >= 1: print("st.shrink(", shrink, ")")
|
||||
st.shrink(shrink)
|
||||
|
||||
def do_stride(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
stride = tuple(random.choice([-2,-1,2]) if i==c else 1 for i in range(len(st.shape)))
|
||||
if DEBUG >= 1: print("st.stride(", stride, ")")
|
||||
st.stride(stride)
|
||||
|
||||
def do_flip(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
stride = tuple(-1 if i==c else 1 for i in range(len(st.shape)))
|
||||
if DEBUG >= 1: print("st.stride(", stride, ")")
|
||||
st.stride(stride)
|
||||
flip = tuple(random.random() < 0.5 for _ in st.shape)
|
||||
if DEBUG >= 1: print("st.flip(", flip, ")")
|
||||
st.flip(flip)
|
||||
|
||||
def do_expand(st):
|
||||
c = [i for i,s in enumerate(st.shape) if s==1]
|
||||
@@ -58,7 +51,7 @@ def do_expand(st):
|
||||
if DEBUG >= 1: print("st.expand(", expand, ")")
|
||||
st.expand(expand)
|
||||
|
||||
shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_stride, do_expand]
|
||||
shapetracker_ops = [do_permute, do_pad, do_shrink, do_reshape_split_one, do_reshape_combine_two, do_flip, do_expand]
|
||||
|
||||
if __name__ == "__main__":
|
||||
random.seed(42)
|
||||
|
||||
@@ -43,9 +43,9 @@ class CheckingShapeTracker:
|
||||
self.t = np.broadcast_to(self.t, new_shape)
|
||||
return self
|
||||
|
||||
def flip(self, axis):
|
||||
self.st = self.st.stride(tuple(-1 if i in axis else 1 for i in range(len(self.shape))))
|
||||
self.t = np.flip(self.t, axis)
|
||||
def flip(self, arg):
|
||||
self.st = self.st.flip(arg)
|
||||
self.t = np.flip(self.t, tuple(i for i in range(len(arg)) if arg[i]))
|
||||
return self
|
||||
|
||||
def shrink(self, arg):
|
||||
@@ -58,11 +58,6 @@ class CheckingShapeTracker:
|
||||
self.t = np.pad(self.t, arg, constant_values=-1)
|
||||
return self
|
||||
|
||||
def stride(self, arg):
|
||||
self.st = self.st.stride(arg)
|
||||
self.t = self.t[tuple([slice(None, None, x) for x in arg])]
|
||||
return self
|
||||
|
||||
def __getitem__(self, val):
|
||||
return self.t.flatten()[val]
|
||||
|
||||
@@ -579,20 +574,18 @@ class TestShapeTrackerFuzzFailures(unittest.TestCase):
|
||||
self.st.shrink(((1, 2), (1, 3), (1, 3)))
|
||||
self.st.reshape((1, 4))
|
||||
self.st.shrink(((0, 1), (1, 3)))
|
||||
print(self.st.st)
|
||||
self.st = self.st.simplify()
|
||||
print(self.st.st)
|
||||
def test_case_2(self):
|
||||
self.st.stride( (1, 1, -2) )
|
||||
self.st.reshape( (3, 6) )
|
||||
self.st.flip( (True, False, True) )
|
||||
self.st.reshape( (3, 9) )
|
||||
self.st.shrink( ((1, 2), (1, 5)) )
|
||||
self.st.stride( (1, -1) )
|
||||
self.st.flip( (True, True) )
|
||||
def test_case_3(self):
|
||||
self.st.shrink( ((0, 2), (0, 2), (0, 1)) )
|
||||
self.st.permute( (1, 0, 2) )
|
||||
self.st.reshape( (4,) )
|
||||
self.st.shrink( ((0, 3),) )
|
||||
self.st.stride( (-1,) )
|
||||
self.st.flip( (True, False) )
|
||||
def test_case_4(self):
|
||||
self.st.reshape( (3, 3, 3, 1) )
|
||||
self.st.pad( ((0, 0), (0, 0), (0, 0), (1, 1)) )
|
||||
@@ -687,21 +680,13 @@ class TestShapeTracker(unittest.TestCase):
|
||||
self.st.reshape((9,6,1))
|
||||
self.st.expand((9,6,4))
|
||||
|
||||
def test_pad_stride(self):
|
||||
def test_pad_flip(self):
|
||||
self.st.pad(((1,4), (1,3)))
|
||||
self.st.stride((2,2))
|
||||
self.st.flip((True, False))
|
||||
|
||||
def test_pad_stride_neg(self):
|
||||
self.st.pad(((1,2), (1,0)))
|
||||
self.st.stride((-1,-1))
|
||||
|
||||
def test_pad_stride_both(self):
|
||||
self.st.pad(((1,2), (1,0)))
|
||||
self.st.stride((-2,-2))
|
||||
|
||||
def test_shrink_pad(self):
|
||||
self.st.shrink(((0,4), (0,4)))
|
||||
self.st.pad(((1,1), (1,1)))
|
||||
def test_pad_flip_int(self):
|
||||
self.st.pad(((1,4), (1,3)))
|
||||
self.st.flip((0, 1))
|
||||
|
||||
def test_reshape(self):
|
||||
new_shape = self.st.shape[::-1]
|
||||
@@ -722,13 +707,13 @@ class TestShapeTracker(unittest.TestCase):
|
||||
self.apply(lambda x: x.expand(tuple(new_shape)))
|
||||
|
||||
def test_flip_0(self):
|
||||
self.apply(lambda x: x.flip((0,)))
|
||||
self.apply(lambda x: x.flip((True, False)))
|
||||
|
||||
def test_flip_1(self):
|
||||
self.apply(lambda x: x.flip((1,)))
|
||||
self.apply(lambda x: x.flip((False, True)))
|
||||
|
||||
def test_flip_01(self):
|
||||
self.apply(lambda x: x.flip((0,1)))
|
||||
self.apply(lambda x: x.flip((True, True)))
|
||||
|
||||
def test_slice_0(self):
|
||||
self.apply(lambda x: x.shrink(((1, x.shape[0]), (0, x.shape[1]))))
|
||||
@@ -754,16 +739,13 @@ class TestShapeTracker(unittest.TestCase):
|
||||
self.apply(lambda x: x.shrink(((0, 2), (3, 4))))
|
||||
self.apply(lambda x: x.expand((2, 10)))
|
||||
|
||||
def test_double_stride(self):
|
||||
self.apply(lambda x: x.stride((1, 2)))
|
||||
self.apply(lambda x: x.stride((2, 1)))
|
||||
def test_double_flip(self):
|
||||
self.apply(lambda x: x.flip((True, False)))
|
||||
self.apply(lambda x: x.flip((True, False)))
|
||||
|
||||
def test_stride(self): self.apply(lambda x: x.stride((2,1)))
|
||||
def test_stride_int(self): self.apply(lambda x: x.stride((1,2)))
|
||||
def test_stride_2(self): self.apply(lambda x: x.stride((2,2)))
|
||||
def test_stride_n(self): self.apply(lambda x: x.stride((-2,1)))
|
||||
def test_stride_int_n(self): self.apply(lambda x: x.stride((-1,2)))
|
||||
def test_stride_2_n(self): self.apply(lambda x: x.stride((-2,-2)))
|
||||
def test_flip(self): self.apply(lambda x: x.flip((True, False)))
|
||||
def test_flip2(self): self.apply(lambda x: x.flip((False, True)))
|
||||
def test_flip3(self): self.apply(lambda x: x.flip((True, True)))
|
||||
|
||||
def test_reshape_then_permute(self):
|
||||
self.test_reshape()
|
||||
@@ -838,24 +820,12 @@ class TestShapeTrackerSize(unittest.TestCase):
|
||||
self.assertEqual(st.real_size(), 10)
|
||||
|
||||
def test_pad_size_multiview(self):
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,))
|
||||
self.assertEqual(st.real_size(), 100)
|
||||
|
||||
# TODO improve real_size accuracy in cases like this?
|
||||
@unittest.expectedFailure
|
||||
def test_stride_size(self):
|
||||
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
|
||||
self.assertEqual(st1.real_size(), 78)
|
||||
self.assertEqual(st2.real_size(), 65)
|
||||
|
||||
def test_stride_size_bounds(self):
|
||||
# lower bound checks that real_size doesn't give false positive for fitting in a buffer
|
||||
# upper bound checks that real_size doesn't exceed N when movementops were applied to from_shape((N,))
|
||||
st1 = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).reshape((16*14,)).stride((17,))
|
||||
st2 = ShapeTracker.from_shape((10,10)).stride((2,1)).reshape((5*10,)).stride((17,))
|
||||
self.assertTrue(78 <= st1.real_size() <= 100)
|
||||
self.assertTrue(65 <= st2.real_size() <= 100)
|
||||
def test_flip_size(self):
|
||||
st = ShapeTracker.from_shape((10,10)).pad(((2,4), (3,1))).flip((True, True))
|
||||
self.assertEqual(st.real_size(), 100)
|
||||
|
||||
class TestConsecutive(unittest.TestCase):
|
||||
@classmethod
|
||||
|
||||
@@ -13,7 +13,7 @@ class MultiShapeTracker:
|
||||
def permute(self, arg): self.sts = [x.permute(arg) for x in self.sts]
|
||||
def expand(self, arg): self.sts = [x.expand(arg) for x in self.sts]
|
||||
def shrink(self, arg): self.sts = [x.shrink(arg) for x in self.sts]
|
||||
def stride(self, arg): self.sts = [x.stride(arg) for x in self.sts]
|
||||
def flip(self, arg): self.sts = [x.flip(arg) for x in self.sts]
|
||||
def pad(self, arg): self.sts = [x.pad(arg) for x in self.sts]
|
||||
|
||||
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
|
||||
@@ -75,8 +75,8 @@ class TestShapeTrackerAdd(unittest.TestCase):
|
||||
backup = st.sts[0]
|
||||
st.sts.append(ShapeTracker.from_shape(backup.shape))
|
||||
st.reshape( (45,) )
|
||||
st.stride( (4,) )
|
||||
st.reshape( (4, 3) )
|
||||
st.flip( (True,) )
|
||||
st.reshape( (15, 3) )
|
||||
assert st_equal(backup + st.sts[1], st.sts[0])
|
||||
|
||||
def test_off_by_one(self):
|
||||
@@ -155,22 +155,17 @@ class TestShapeTrackerInvert(unittest.TestCase):
|
||||
|
||||
def test_can_invert_flip(self):
|
||||
a = ShapeTracker.from_shape((20, 10))
|
||||
x = a.stride((-1,1))
|
||||
x = a.flip((True,False))
|
||||
ap = x + x.invert(a.shape)
|
||||
assert st_equal(ap, a)
|
||||
|
||||
def test_can_invert_flip_permute(self):
|
||||
a = ShapeTracker.from_shape((20, 10))
|
||||
x = a.permute((1,0))
|
||||
x = x.stride((-1,1))
|
||||
x = x.flip((True,False))
|
||||
ap = x + x.invert(a.shape)
|
||||
assert st_equal(ap, a)
|
||||
|
||||
def test_cant_invert_stride(self):
|
||||
a = ShapeTracker.from_shape((10, 10))
|
||||
x = a.stride((2,2))
|
||||
assert x.invert(a.shape) is None
|
||||
|
||||
def test_invert_failure(self):
|
||||
a = ShapeTracker.from_shape((2, 5))
|
||||
x = a.pad( ((2, 0), (0, 0)) )
|
||||
|
||||
@@ -125,9 +125,9 @@ def shrink_multi(root:UOp, multi:UOp):
|
||||
return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
|
||||
axis=multi.axis, real=multi.real)
|
||||
|
||||
def stride_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or root.arg[multi.axis] == 1, "flipping not supported on sharded axis"
|
||||
return UOp.multi(*[x.stride(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
|
||||
def flip_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or not root.arg[multi.axis], "flipping not supported on sharded axis"
|
||||
return UOp.multi(*[x.flip(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
|
||||
|
||||
def copy_multi(multi:UOp, device:UOp):
|
||||
# if we already have a copy on the device, return that
|
||||
@@ -155,7 +155,7 @@ multi_pm = PatternMatcher([
|
||||
(UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
|
||||
(UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
|
||||
(UPat(Ops.STRIDE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), stride_multi),
|
||||
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
|
||||
@@ -33,7 +33,7 @@ pm_gradient = PatternMatcher([
|
||||
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
|
||||
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
|
||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
|
||||
(UPat(Ops.STRIDE, name="ret"), lambda ctx, ret: (ctx.stride(ret.arg) if all(x in {-1,1} for x in ret.arg) else None,)),
|
||||
(UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.arg),)),
|
||||
# TODO: this cast can be removed by putting the casts around the EXPAND
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
|
||||
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
|
||||
|
||||
@@ -105,7 +105,7 @@ class Ops(FastEnum):
|
||||
BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702
|
||||
|
||||
# movement ops!
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
|
||||
|
||||
# misc ops
|
||||
UNROLL = auto(); CONTRACT = auto() # noqa: E702
|
||||
@@ -160,7 +160,7 @@ class GroupOp:
|
||||
ALU = set.union(Unary, Binary, Ternary)
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.STRIDE}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
||||
|
||||
Buffer = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
|
||||
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
|
||||
@@ -513,7 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
|
||||
def permute(self, arg:tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
|
||||
def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
|
||||
def stride(self, arg:tuple[sint, ...]): return self._mop(Ops.STRIDE, arg)
|
||||
def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg)
|
||||
|
||||
# *** uop Buffer stuff ***
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ class ShapeTracker:
|
||||
def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
|
||||
def expand(self, new_shape: tuple[sint, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape), ))
|
||||
def permute(self, axis: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis), ))
|
||||
def stride(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul), ))
|
||||
def flip(self, mul: tuple[int, ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].flip(mul), ))
|
||||
|
||||
def reshape(self, new_shape: tuple[sint, ...]) -> ShapeTracker:
|
||||
if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
|
||||
@@ -139,4 +139,4 @@ class ShapeTracker:
|
||||
def mop(self, op, arg): return mops[op](self, arg)
|
||||
|
||||
mops: dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
|
||||
Ops.SHRINK: ShapeTracker.shrink, Ops.STRIDE: ShapeTracker.stride, Ops.PAD: ShapeTracker.pad}
|
||||
Ops.SHRINK: ShapeTracker.shrink, Ops.FLIP: ShapeTracker.flip, Ops.PAD: ShapeTracker.pad}
|
||||
|
||||
@@ -212,7 +212,7 @@ class View:
|
||||
def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
|
||||
ret = View.create(self.shape)
|
||||
if self.mask: ret = ret.shrink(self.mask)
|
||||
ret = ret.stride(tuple(-1 if x < 0 else 1 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
|
||||
ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
|
||||
return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
@@ -265,15 +265,10 @@ class View:
|
||||
tuple(self.mask[a] for a in axis) if self.mask is not None else None)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def stride(self, mul: tuple[int, ...]) -> View:
|
||||
# except for the negative case, you can build this from the others. invertible in the negative case
|
||||
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
|
||||
strides = tuple([z*m for z,m in zip(self.strides, mul)])
|
||||
new_shape = tuple([ceildiv(s, abs(m)) for s,m in zip(self.shape, mul)])
|
||||
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.strides, mul) if m < 0])
|
||||
mask = tuple([(ceildiv(mx if m > 0 else s-my, abs(m)), ceildiv(my if m > 0 else s-mx, abs(m))) \
|
||||
for (mx,my),s,m in zip(self.mask, self.shape, mul)]) if self.mask is not None else None
|
||||
return View.create(new_shape, strides, self.offset + offset, mask)
|
||||
def flip(self, arg: tuple[bool, ...]) -> View:
|
||||
offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f)
|
||||
mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None
|
||||
return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
|
||||
|
||||
@@ -1005,7 +1005,7 @@ class Tensor(SimpleMathTrait):
|
||||
"""
|
||||
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
|
||||
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
|
||||
return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))]))
|
||||
return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
|
||||
|
||||
def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user