From b64738e1d6e942e3452f47bafd59836cd17df278 Mon Sep 17 00:00:00 2001
From: forcefieldsovereign <146239560+forcefieldsovereign@users.noreply.github.com>
Date: Wed, 15 Nov 2023 12:50:17 -0800
Subject: [PATCH] Remove AS_STRIDED from shapetracker (#2216)

* very close
* remove comment
* negative strides working
* almost everything passes
* calculate offset with list comprehension
* some cleanup
* got disk load working
* review suggestions
* fix after merge
* overlap working
* did it
* clean
* fixed disk load
* lint
* mypy
* removed as_strided
* trying without simplify
* added back simplify
* make sure expanding to smaller shape
* cleanup
* removed comment
* removed env file
* trying whisper test again
* onnx test sqlite issue
* working on test
* finished test
* eliminate unnecessary shrink-then-pad
* don't shrink buffer
* added strides check
* added to ci under linters
* switch issue
* allow symbolic stride
* removed .env
* isinstance
* adjust strides for double expand
* cleanup
* needed to add type hint for mypy
* set pythonpath
---
 .github/workflows/test.yml     |   2 +
 extra/to_movement_ops.py       | 105 +++++++++++++++++++++++++++++++++
 tinygrad/ops.py                |   2 +-
 tinygrad/runtime/ops_cpu.py    |   1 -
 tinygrad/runtime/ops_disk.py   |  14 ++---
 tinygrad/runtime/ops_shm.py    |   2 +-
 tinygrad/runtime/ops_torch.py  |   7 ---
 tinygrad/shape/shapetracker.py |  35 ++++++++---
 8 files changed, 143 insertions(+), 25 deletions(-)
 create mode 100644 extra/to_movement_ops.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 950e6e6071..4e90f1a8d6 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -53,6 +53,8 @@ jobs:
       run: python test/external/fuzz_symbolic.py
     - name: Fuzz Test shapetracker
       run: PYTHONPATH="." python test/external/fuzz_shapetracker.py
+    - name: Test shapetracker to_movement_ops
+      run: PYTHONPATH="." python extra/to_movement_ops.py
     - name: Use as an external package
       run: |
         mkdir $HOME/test_external_dir
diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py
new file mode 100644
index 0000000000..e82ae25828
--- /dev/null
+++ b/extra/to_movement_ops.py
@@ -0,0 +1,105 @@
+import random
+from tqdm import tqdm
+from extra.optimization.helpers import load_worlds
+from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.ops import LazyOp, MovementOps, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
+from tinygrad.helpers import dtypes, prod
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import View
+from tinygrad.shape.symbolic import Node, Variable
+inf, nan = float('inf'), float('nan')
+
+def get_real_view(shape, strides, offset, mask):
+  real_shape = tuple(y-x for x,y in mask) if mask else shape
+  offset = offset + sum(st * (s-1) for s,st in zip(real_shape, strides) if st<0)
+  real_offset = offset + (sum(x*st for (x,_),st in zip(mask, strides)) if mask else 0)
+  real_real_shape = [s for s,st in zip(real_shape, strides) if st]
+  strides = [abs(st) if isinstance(st,int) else st for st in strides if st]
+  return real_real_shape, strides, real_offset
+
+def get_buffer_size(shape, strides, offset, mask):
+  real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask)
+  return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1
+
+def flatten_view(view: View):
+  real_real_shape, strides, real_offset = get_real_view(view.shape, view.strides, view.offset, view.mask)
+  def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
+  ordered_shape_strides, _ = sort_by_strides(real_real_shape, strides)
+  ordered_shape_strides = [list(s) for s in ordered_shape_strides]
+  if strides:
+    i = 0
+    while i < len(ordered_shape_strides):
+      if i<len(ordered_shape_strides)-1 and ordered_shape_strides[i][1] == ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
+        ordered_shape_strides[i][0] = ordered_shape_strides[i][0]*ordered_shape_strides[i+1][0]
+        ordered_shape_strides[i][1] = ordered_shape_strides[i+1][1]
+        ordered_shape_strides.pop(i+1)
+      else: i += 1
+    return (tuple(s[0] for s in ordered_shape_strides), tuple(s[1] for s in ordered_shape_strides), real_offset)
+  return (tuple(real_real_shape), tuple(strides), real_offset)
+
+def views_equivalent(v1: View, v2: View) -> bool:
+  return v1 == v2 or flatten_view(v1) == flatten_view(v2)
+
+def st_equivalent(st: ShapeTracker, st_rebuilt: ShapeTracker):
+  views = list(st.views)
+  rebuilt_views = list(st_rebuilt.views)
+  i = 0
+  while i < len(views):
+    view, rebuilt_view = views[i], rebuilt_views[i]
+    if view == rebuilt_view:
+      i += 1
+      continue
+    elif view.shape == rebuilt_view.shape:
+      i += 1
+    # hack to skip expands for overlapped strides
+    else:
+      rebuilt_views.pop(i)
+  return True
+
+def test_rebuild(st: ShapeTracker):
+  rebuilt_st = ShapeTracker.from_shape((get_buffer_size(st.views[0].shape, st.views[0].strides, st.views[0].offset, st.views[0].mask),))
+  for mop, arg in st.to_movement_ops():
+    if mop == MovementOps.RESHAPE:
+      # shapetracker doesn't allow flattening with -1 but it's required for MovementOps.RESHAPE
+      if arg == (-1,):
+        rebuilt_st = rebuilt_st.reshape((prod(rebuilt_st.views[-1].shape),))
+      else:
+        rebuilt_st = rebuilt_st.reshape(arg)
+    elif mop == MovementOps.PERMUTE:
+      rebuilt_st = rebuilt_st.permute(arg)
+    elif mop == MovementOps.EXPAND:
+      if len(arg) != len(rebuilt_st.shape):
+        rebuilt_st = rebuilt_st.reshape((1,*rebuilt_st.shape))
+      rebuilt_st = rebuilt_st.expand(arg)
+    elif mop == MovementOps.PAD:
+      rebuilt_st = rebuilt_st.pad(arg)
+    elif mop == MovementOps.SHRINK:
+      rebuilt_st = rebuilt_st.shrink(arg)
+    elif mop == MovementOps.STRIDE:
+      rebuilt_st = rebuilt_st.stride(arg)
+    else:
+      raise Exception("invalid mop")
+  rebuilt_st = rebuilt_st.simplify()
+  if len(st.views) != len(rebuilt_st.views):
+    if not set(st.views).issubset(set(rebuilt_st.views)):
+      assert st_equivalent(st, rebuilt_st)
+  else:
+    for v1,v2 in zip(st.views, rebuilt_st.views):
+      assert views_equivalent(v1, v2), f"{v1} not equivalent to {v2}"
+  last_v1 = st.views[-1]
+  last_v2 = rebuilt_st.views[-1]
+  assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
+
+if __name__ == "__main__":
+  ast_strs = load_worlds(False, False, True)
+  random.shuffle(ast_strs)
+  ast_strs = ast_strs[:2000]
+  def interpret_ast(ast):
+    if ast.op in BufferOps:
+      test_rebuild(ast.arg.st)
+    else:
+      for src in ast.src: interpret_ast(src)
+  for ast_str in tqdm(ast_strs):
+    ast = eval(ast_str)
+    interpret_ast(ast)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index ad2636cf6a..e715bec423 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -17,7 +17,7 @@ class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
 # Ops below this line are not allowed in ASTs
-class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
+class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
 class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
 
 Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index 2bc6602449..ca43fa83a6 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -41,7 +41,6 @@ numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(output_type(x, y), copy=False), UnaryOps.SQRT: np.sqrt,
   MovementOps.PERMUTE: lambda x, order: x.transpose(order), MovementOps.PAD: np.pad, MovementOps.EXPAND: np.broadcast_to,
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, i) for i in arg)],
-  MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], buffer=np.require(x, requirements='C'), dtype=x.dtype, offset=arg[2]*x.dtype.itemsize, strides=tuple(y*x.dtype.itemsize for y in arg[1])),
   TernaryOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True), lambda x: x.strides, np.broadcast_to),
   TernaryOps.WHERE: np.where,
 }}
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index 4bff4e0ad6..40fc8ae826 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -24,18 +24,16 @@ class RawDiskBuffer(RawBufferMapped):
   def cast(self, arg:Tuple[DType, bool]): return RawDiskBuffer(self.size, arg[0], buf=self._buf, shape=self.shape, offset=self.offset)
   def reshape(self, arg): return RawDiskBuffer(self.size, self.dtype, buf=self._buf, shape=arg, offset=self.offset)
   def shrink(self, arg):
-    assert arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
-    offset = arg[0][0]*prod(self.shape[1:])*self.dtype.itemsize
-    size = (arg[0][1]-arg[0][0]) * prod(self.shape[1:])
-    return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+self.shape[1:])
-
-  def as_strided(self, arg):
-    return RawDiskBuffer(prod(arg[0]), self.dtype, buf=self._buf, offset=self.offset+arg[2]*self.dtype.itemsize, shape=arg[0])
+    assert len(arg)<2 or arg[1:] == tuple([(0,x) for x in self.shape[1:]]), f"can only slice the first dim of disk tensor {arg}"
+    offset = arg[0][0]*(prod(self.shape[1:]) if len(arg)>1 else 1)*self.dtype.itemsize
+    size = (arg[0][1]-arg[0][0]) * (prod(self.shape[1:]) if len(arg)>1 else 1)
+    return RawDiskBuffer(size, self.dtype, buf=self._buf, offset=self.offset+offset, shape=(arg[0][1]-arg[0][0],)+(self.shape[1:] if len(arg)>1 else ()))
   def _buffer(self): return memoryview(self._buf[1])[self.offset:self.offset+self.size*self.dtype.itemsize]
   def readinto(self, buf):
     self._buf[0].seek(self.offset)
     self._buf[0].readinto(buf)
 
-disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
+disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.SHRINK: RawDiskBuffer.shrink, MovementOps.RESHAPE: RawDiskBuffer.reshape }
 DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op)
+
diff --git a/tinygrad/runtime/ops_shm.py b/tinygrad/runtime/ops_shm.py
index 8e5419b7a0..f72502d6fa 100644
--- a/tinygrad/runtime/ops_shm.py
+++ b/tinygrad/runtime/ops_shm.py
@@ -24,5 +24,5 @@ class RawShmBuffer(RawBufferMapped):
   def _buffer(self): return memoryview(self._buf)
 
 # TODO: is this wrong?
-shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
+shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x }
 ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op)
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index e489ba02d9..b925aeca42 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -16,12 +16,6 @@ def match_types(x, y, disallow_bool=False):
   if disallow_bool and up == torch.bool: up = torch.float
   return x.type(up), y.type(up)
 
-def as_strided(x, arg):
-  if any(i < 0 for i in arg[1]):
-    return torch.as_strided(x.contiguous(), arg[0], tuple(abs(i) for i in arg[1]),
-      arg[2] + sum((s-1)*a if a < 0 else 0 for (s,a) in zip(arg[0], arg[1]))).flip([i for i,a in enumerate(arg[1]) if a < 0])
-  return torch.as_strided(x.contiguous(), arg[0], arg[1], arg[2])
-
 torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   # TODO: torch.tensor should work here
   #BufferOps.CONST: lambda val, dtype: torch.tensor(val, device=device, dtype=inverse_type_map[dtype]),
@@ -38,7 +32,6 @@ torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
   TernaryOps.WHERE: lambda x, y, z: torch.where(x != 0, y, z),
   MovementOps.STRIDE: lambda x, arg: x[tuple(slice(None, None, abs(i)) for i in arg)].flip([i for i,a in enumerate(arg) if a < 0]),
   MovementOps.EXPAND: lambda x, arg: x.expand(arg), MovementOps.PERMUTE: lambda x, arg: x.permute(arg),
-  MovementOps.AS_STRIDED: as_strided
 }}
 
 class RawTorchBuffer(RawBuffer):
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 71346e3815..d78a092f04 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -97,14 +97,35 @@ class ShapeTracker:
 
   def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
     to_apply:List[Tuple[MovementOps, Tuple]] = []
-    for v in self.views:
+    for i, v in enumerate(self.views):
       real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
-      real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
-      # first, we apply the offset
-      # then, we make it the correct shape
-      # then, we apply permutations
-      # TODO: don't use as_strided
-      to_apply.append((MovementOps.AS_STRIDED, (tuple([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)]), v.strides, real_offset)))
+      offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
+      real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
+      real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
+      strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
+      buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
+      if i: buffer_size = prod(self.views[i-1].shape) - real_offset
+      def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
+      ordered_shape_strides, order = sort_by_strides(real_real_shape, strides)
+      to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))])
+      if strides:
+        if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),)))
+        for i, shape_stride in enumerate(ordered_shape_strides):
+          if i<len(ordered_shape_strides)-1 and shape_stride[1]<ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
+            remaining_buffer = ordered_shape_strides[i-1][1] if i>0 else buffer_size
+            to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer)))
+            to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1)))
+            to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer)))
+            to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1]))))
+            to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1])))
+            ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1])
+          else:
+            to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1]))))
+            to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1])))
+        to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))])
+        if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides)))))
+      to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides))))
+      if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides)))
       # then, we apply pre expand pads
       if v.mask is not None:
         pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
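
Note: a minimal sketch of exercising the new decomposition by hand, using only APIs that appear in this patch (ShapeTracker.from_shape, permute, to_movement_ops). The 4x4 shape is a hypothetical example; the exact op sequence depends on the view, but with AS_STRIDED removed it can only contain the remaining MovementOps.

  from tinygrad.shape.shapetracker import ShapeTracker

  # hypothetical input: a transposed 4x4 view
  st = ShapeTracker.from_shape((4, 4)).permute((1, 0))
  # print the RESHAPE/SHRINK/PAD/EXPAND/PERMUTE/STRIDE steps that replace a single AS_STRIDED
  for mop, arg in st.to_movement_ops():
    print(mop, arg)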