From 727416201fc05e81db7ef5af8d9af43e35c80f65 Mon Sep 17 00:00:00 2001 From: Rayan Hatout Date: Tue, 13 Jun 2023 02:13:21 +0100 Subject: [PATCH] Shapetracker optimizations (#966) * optimizations in shapetracker.py * revert micro-optimizations in assertions * make mypy happy * list comp instead of map in get_unsafe_resize_offset * list comp instead of map in get_unsafe_resize_offset --- tinygrad/shape/shapetracker.py | 165 ++++++++++++++++++++------------- 1 file changed, 101 insertions(+), 64 deletions(-) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 19a53a3158..4831e0debb 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -1,8 +1,8 @@ # ShapeTracker allows movement operations to a buffer that don't require a copy to be made. from __future__ import annotations -import functools from enum import Enum, auto -from typing import Tuple, Union, List, Optional, Dict, Callable +import functools +from typing import Dict, Tuple, Union, List, Optional, Callable, cast from tinygrad.helpers import prod, DEBUG from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, ModNode @@ -11,8 +11,8 @@ class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PA def check_no_mul(test, var): if test == var: return True - if isinstance(test, SumNode): return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay - if isinstance(test, ModNode) and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay + if test.__class__ is SumNode: return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay + if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay return False @functools.lru_cache(maxsize=None) @@ -27,22 +27,33 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup return ret @functools.lru_cache(maxsize=None) -def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape))) +def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all([s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape))]) + +@functools.lru_cache(maxsize=None) +def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]: + new_strides = [] + for stride, shp in zip(strides, shape): + if shp != 1: new_strides.append(stride) + else: new_strides.append(0) + return tuple(new_strides) class View: + __slots__ = "shape", "strides", "offset", "mask", "shape_strides", "contiguous" def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0, mask:Optional[Tuple[Tuple[int, int], ...]]=None): - self.shape, self.strides, self.offset = shape, tuple(stride if shp != 1 else 0 for stride,shp in zip(strides, shape)), offset + self.shape, self.offset = shape, offset + self.strides = filter_strides(shape, strides) self.mask = mask - self.shape_strides = to_shape_strides(self.shape, self.strides) - self.contiguous: bool = self.offset == 0 and is_contiguous(self.shape, self.strides) and mask is None + self.shape_strides = to_shape_strides(shape, self.strides) + self.contiguous: bool = offset == 0 and is_contiguous(shape, self.strides) and mask is None def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset}, {self.mask})" + def key(self): return (self.shape, self.strides, self.offset, self.mask) def expr_node_mask(self, idx, valid=None) -> Node: expr = [valid] if valid is not None else [] if self.mask is not None: acc = 1 - for ns,(x,y) in list(zip(self.shape, self.mask))[::-1]: + for ns,(x,y) in reversed(list(zip(self.shape, self.mask))): base = ((idx//acc) % ns) expr += [base >= x, base < y] acc *= ns @@ -52,7 +63,7 @@ class View: assert len(idxs) == len(self.shape), "need an idx for all dimensions" acc = 1 ret = [] - for tidx,d in list(zip(idxs, self.shape))[::-1]: + for tidx,d in reversed(list(zip(idxs, self.shape))): ret.append(tidx * acc) acc *= d return Variable.sum(ret) @@ -62,7 +73,7 @@ class View: if idx is None: idx = Variable('idx', 0, prod(self.shape)) ret = [Variable.num(self.offset)] acc = 1 - for d,s in self.shape_strides[::-1]: + for d,s in reversed(self.shape_strides): ret.append(((idx//acc)%d)*s) acc *= d return Variable.sum(ret) @@ -76,7 +87,7 @@ class View: def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]: strides = [1] if shape else [] for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides - return tuple(st if s != 1 else 0 for st, s in zip(strides, shape)) + return tuple([st if s != 1 else 0 for st, s in zip(strides, shape)]) @functools.lru_cache(maxsize=None) def view_from_shape(shape:Tuple[int, ...]) -> View: @@ -92,35 +103,74 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]: this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st) if s == 1: new_strides.append(0) # all shape 1 can have stride 0 - elif isinstance(this_dim, NumNode) and this_dim.b == 0: + elif this_dim.__class__ == NumNode and this_dim.b == 0: new_strides.append(0) - elif isinstance(this_dim, Variable): + elif this_dim.__class__ == Variable: new_strides.append(1) - elif isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable): + elif this_dim.__class__ == MulNode and cast(MulNode, this_dim).a.__class__ == Variable: new_strides.append(this_dim.b) else: if DEBUG >= 4: print("can't simplify", s, this_dim.render()) break return View(vm1.shape, tuple(new_strides), new_offset.b, vm1.mask) if len(new_strides) == len(vm1.strides) else None +@functools.lru_cache(maxsize=None) +def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]: + shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset + # check if this is adding or removing 1s (only) + # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional) + if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]: + new_strides: List[int] = [y for x,y in zip(shape, strides) if x != 1] + new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape]) + new_mask_tuple = None + if mask: + for x,y in zip(shape, mask): + if x == 1 and y != (0, 1): + new_mask_tuple = tuple([(0,0) for _ in new_shape]) + break + else: + new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1] + new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape]) + return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False + + new_view = View(new_shape, strides_for_shape(new_shape)) + if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset + else: + if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False + else: + if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}") + return new_view, True + +@functools.lru_cache(maxsize=None) +def get_pad_args(shape, arg: Tuple[Tuple[int, int], ...]): + return tuple([(-b,s+e) for s,(b,e) in zip(shape, arg)]), tuple([(b,s+b) for s,(b,_) in zip(shape, arg)]) + +@functools.lru_cache(maxsize=None) +def get_unsafe_resize_offset(strides, arg): + return sum([s * x[0] for s, x in zip(strides,arg)]) + class ShapeTracker: + __slots__ = "views" def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None): - self.views: List[View] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)]) - def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})" - def copy(self) -> ShapeTracker: return ShapeTracker(self.shape, self.views[:]) + self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ == ShapeTracker else [view_from_shape(shape)]) + def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})" + def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views]) @property - def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous + def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous @property def shape(self) -> Tuple[int, ...]: return self.views[-1].shape + @property + def key(self) -> Tuple[int, ...]: return tuple(map(View.key, self.views)) + # this is the real size (ish) - def size(self): return prod([s for s,st in zip(self.shape, self.views[-1].strides) if st != 0]) + def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0]) def unit_stride_axes(self) -> List[int]: ret, acc = [], 1 - for j,s in list(enumerate(self.shape))[::-1]: + for j,s in reversed(list(enumerate(self.shape))): if s == 1: continue var = Variable('idx', 0, s-1) this_dim = self.expr_node(var*acc) @@ -129,7 +179,7 @@ class ShapeTracker: return ret def _expr_idx(self, idx, valid): - for v in self.views[0:-1][::-1]: + for v in reversed(self.views[0:-1]): valid = v.expr_node_mask(idx, valid) idx = v.expr_node(idx) return idx, valid @@ -149,85 +199,72 @@ class ShapeTracker: return self._expr_idx(idx, valid) def expr_node(self, idx='idx'): - if isinstance(idx, str): idx = Variable(idx, 0, prod(self.shape)-1) + if idx.__class__ == str: idx = Variable(idx, 0, prod(self.shape)-1) return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx)) def needs_valid(self) -> bool: - return any(v.mask is not None for v in self.views) + return any([v.mask is not None for v in self.views]) # *** under this line are the movement ops *** def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None): - offset = sum([self.views[-1].strides[i]*x for i,(x,_) in enumerate(arg)]) + offset = get_unsafe_resize_offset(self.views[-1].strides, arg) if self.views[-1].mask: # move the old mask - nmask = tuple((max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.views[-1].mask, arg)) + nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.views[-1].mask, arg)]) # merge the masks if we have two - mask = tuple((max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)) if mask is not None else nmask - self.views[-1] = View(tuple(y-x for x,y in arg), self.views[-1].strides, self.views[-1].offset+offset, mask) + mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask + self.views[-1] = View(tuple([y-x for x,y in arg]), self.views[-1].strides, self.views[-1].offset+offset, mask) def pad(self, arg: Tuple[Tuple[int, int], ...]): assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape) - if all(b==0 and e==0 for b,e in arg): return self - zvarg = tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg)) - self.__unsafe_resize(zvarg, mask=tuple((b,s+b) for s,(b,_) in zip(self.shape, arg))) + if any([b or e for b, e in arg]): + zvarg, mask = get_pad_args(self.shape, arg) + self.__unsafe_resize(zvarg, mask=mask) + return self + return self def shrink(self, arg: Tuple[Tuple[int, int], ...]): assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape) self.__unsafe_resize(arg) + return self - def expand(self, new_shape: Tuple[int, ...]): + def expand(self, new_shape: Tuple[int, ...]) -> ShapeTracker: + assert len(new_shape) == len(self.views[-1].shape) assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}" # NOTE: can the mask ever be (0,0)? - mask = tuple((((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)) if self.views[-1].mask else None + mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask) + return self def reshape(self, new_shape: Tuple[int, ...]): - if self.shape == new_shape: return self + if self.views[-1].shape == new_shape: return self assert all(isinstance(x, int) and x > 0 for x in new_shape), f"shape must be ints and can't contain 0 or negative numbers {new_shape}" assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}" - - # check if this is adding or removing 1s (only) - # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional) - if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1): - old_strides = [y for x,y in zip(self.shape, self.views[-1].strides) if x != 1] - new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape) - new_mask_tuple = None - if self.views[-1].mask: - if any(y!=(0,1) for x,y in zip(self.shape, self.views[-1].mask) if x == 1): - # mask it all out! - new_mask_tuple = tuple((0,0) for _ in new_shape) - else: - old_mask = [y for x,y in zip(self.shape, self.views[-1].mask) if x != 1] - new_mask_tuple = tuple((0,1) if x == 1 else old_mask.pop(0) for x in new_shape) - self.views[-1] = View(new_shape, new_strides_tuple, self.views[-1].offset, new_mask_tuple) - return self - - view = View(new_shape, strides_for_shape(new_shape)) - if self.contiguous: self.views[-1] = view # NOTE: if it's contiguous it can't have an offset - else: - if (merged_view := merge_views(self.views[-1], view)) is not None: self.views[-1] = merged_view - else: - if DEBUG >= 4: print(f"WARNING: creating new view with reshape {self} -> {new_shape}") - self.views.append(view) + new_view, extra = _reshape(self.views[-1], new_shape) + if extra: self.views.append(new_view) + else: self.views[-1] = new_view + return self def permute(self, axis: Tuple[int, ...]): assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}" assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}" - self.views[-1] = View(tuple(self.shape[a] for a in axis), tuple(self.views[-1].strides[a] for a in axis), self.views[-1].offset, tuple(self.views[-1].mask[a] for a in axis) if self.views[-1].mask is not None else None) + self.views[-1] = View(tuple([self.views[-1].shape[a] for a in axis]), tuple([self.views[-1].strides[a] for a in axis]), self.views[-1].offset, tuple([self.views[-1].mask[a] for a in axis]) if self.views[-1].mask is not None else None) + return self # except for the negative case, you can build this from the others. invertible in the negative case def stride(self, mul: Tuple[int, ...]): assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}" - strides = tuple(z*m for z,m in zip(self.views[-1].strides, mul)) - new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul)) - offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.views[-1].strides, mul) if m < 0]) - mask = tuple((((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.views[-1].mask, self.shape, mul)) if self.views[-1].mask is not None else None + strides = tuple([z*m for z,m in zip(self.views[-1].strides, mul)]) + new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.views[-1].shape, mul)]) + offset = sum([(s-1)*z for s,z,m in zip(self.views[-1].shape, self.views[-1].strides, mul) if m < 0]) + mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.views[-1].mask, self.views[-1].shape, mul)]) if self.views[-1].mask is not None else None self.views[-1] = View(new_shape, strides, self.views[-1].offset + offset, mask) + return self # *** entry point for external *** - def movement_op(self, op, arg:Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> ShapeTracker: + def movement_op(self, op: MovementOps, arg:Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> ShapeTracker: assert isinstance(arg, tuple) and (len(arg) == len(self.shape) or op == MovementOps.RESHAPE), f"arg {arg} for {op} doesn't match dim of shape {self.shape}" dispatch[op](self, arg) return self