mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 18:35:12 -05:00
Shapetracker optimizations (#966)
* optimizations in shapetracker.py * revert micro-optimizations in assertions * make mypy happy * list comp instead of map in get_unsafe_resize_offset * list comp instead of map in get_unsafe_resize_offset
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from enum import Enum, auto
|
||||
from typing import Tuple, Union, List, Optional, Dict, Callable
|
||||
import functools
|
||||
from typing import Dict, Tuple, Union, List, Optional, Callable, cast
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, ModNode
|
||||
|
||||
@@ -11,8 +11,8 @@ class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PA
|
||||
|
||||
def check_no_mul(test, var):
|
||||
if test == var: return True
|
||||
if isinstance(test, SumNode): return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay
|
||||
if isinstance(test, ModNode) and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
|
||||
if test.__class__ is SumNode: return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay
|
||||
if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
|
||||
return False
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@@ -27,22 +27,33 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
|
||||
return ret
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))
|
||||
def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all([s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape))])
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
new_strides = []
|
||||
for stride, shp in zip(strides, shape):
|
||||
if shp != 1: new_strides.append(stride)
|
||||
else: new_strides.append(0)
|
||||
return tuple(new_strides)
|
||||
|
||||
class View:
|
||||
__slots__ = "shape", "strides", "offset", "mask", "shape_strides", "contiguous"
|
||||
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0, mask:Optional[Tuple[Tuple[int, int], ...]]=None):
|
||||
self.shape, self.strides, self.offset = shape, tuple(stride if shp != 1 else 0 for stride,shp in zip(strides, shape)), offset
|
||||
self.shape, self.offset = shape, offset
|
||||
self.strides = filter_strides(shape, strides)
|
||||
self.mask = mask
|
||||
self.shape_strides = to_shape_strides(self.shape, self.strides)
|
||||
self.contiguous: bool = self.offset == 0 and is_contiguous(self.shape, self.strides) and mask is None
|
||||
self.shape_strides = to_shape_strides(shape, self.strides)
|
||||
self.contiguous: bool = offset == 0 and is_contiguous(shape, self.strides) and mask is None
|
||||
|
||||
def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset}, {self.mask})"
|
||||
def key(self): return (self.shape, self.strides, self.offset, self.mask)
|
||||
|
||||
def expr_node_mask(self, idx, valid=None) -> Node:
|
||||
expr = [valid] if valid is not None else []
|
||||
if self.mask is not None:
|
||||
acc = 1
|
||||
for ns,(x,y) in list(zip(self.shape, self.mask))[::-1]:
|
||||
for ns,(x,y) in reversed(list(zip(self.shape, self.mask))):
|
||||
base = ((idx//acc) % ns)
|
||||
expr += [base >= x, base < y]
|
||||
acc *= ns
|
||||
@@ -52,7 +63,7 @@ class View:
|
||||
assert len(idxs) == len(self.shape), "need an idx for all dimensions"
|
||||
acc = 1
|
||||
ret = []
|
||||
for tidx,d in list(zip(idxs, self.shape))[::-1]:
|
||||
for tidx,d in reversed(list(zip(idxs, self.shape))):
|
||||
ret.append(tidx * acc)
|
||||
acc *= d
|
||||
return Variable.sum(ret)
|
||||
@@ -62,7 +73,7 @@ class View:
|
||||
if idx is None: idx = Variable('idx', 0, prod(self.shape))
|
||||
ret = [Variable.num(self.offset)]
|
||||
acc = 1
|
||||
for d,s in self.shape_strides[::-1]:
|
||||
for d,s in reversed(self.shape_strides):
|
||||
ret.append(((idx//acc)%d)*s)
|
||||
acc *= d
|
||||
return Variable.sum(ret)
|
||||
@@ -76,7 +87,7 @@ class View:
|
||||
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
strides = [1] if shape else []
|
||||
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
|
||||
return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))
|
||||
return tuple([st if s != 1 else 0 for st, s in zip(strides, shape)])
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def view_from_shape(shape:Tuple[int, ...]) -> View:
|
||||
@@ -92,35 +103,74 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st)
|
||||
if s == 1:
|
||||
new_strides.append(0) # all shape 1 can have stride 0
|
||||
elif isinstance(this_dim, NumNode) and this_dim.b == 0:
|
||||
elif this_dim.__class__ == NumNode and this_dim.b == 0:
|
||||
new_strides.append(0)
|
||||
elif isinstance(this_dim, Variable):
|
||||
elif this_dim.__class__ == Variable:
|
||||
new_strides.append(1)
|
||||
elif isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
|
||||
elif this_dim.__class__ == MulNode and cast(MulNode, this_dim).a.__class__ == Variable:
|
||||
new_strides.append(this_dim.b)
|
||||
else:
|
||||
if DEBUG >= 4: print("can't simplify", s, this_dim.render())
|
||||
break
|
||||
return View(vm1.shape, tuple(new_strides), new_offset.b, vm1.mask) if len(new_strides) == len(vm1.strides) else None
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
|
||||
shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
|
||||
# check if this is adding or removing 1s (only)
|
||||
# NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
|
||||
if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]:
|
||||
new_strides: List[int] = [y for x,y in zip(shape, strides) if x != 1]
|
||||
new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
|
||||
new_mask_tuple = None
|
||||
if mask:
|
||||
for x,y in zip(shape, mask):
|
||||
if x == 1 and y != (0, 1):
|
||||
new_mask_tuple = tuple([(0,0) for _ in new_shape])
|
||||
break
|
||||
else:
|
||||
new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1]
|
||||
new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
|
||||
return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False
|
||||
|
||||
new_view = View(new_shape, strides_for_shape(new_shape))
|
||||
if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
|
||||
else:
|
||||
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
|
||||
else:
|
||||
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
|
||||
return new_view, True
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_pad_args(shape, arg: Tuple[Tuple[int, int], ...]):
|
||||
return tuple([(-b,s+e) for s,(b,e) in zip(shape, arg)]), tuple([(b,s+b) for s,(b,_) in zip(shape, arg)])
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_unsafe_resize_offset(strides, arg):
|
||||
return sum([s * x[0] for s, x in zip(strides,arg)])
|
||||
|
||||
class ShapeTracker:
|
||||
__slots__ = "views"
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None):
|
||||
self.views: List[View] = views if views is not None else (shape.views[:] if isinstance(shape, ShapeTracker) else [view_from_shape(shape)])
|
||||
def __repr__(self): return f"ShapeTracker(shape={self.shape}, views={self.views})"
|
||||
def copy(self) -> ShapeTracker: return ShapeTracker(self.shape, self.views[:])
|
||||
self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ == ShapeTracker else [view_from_shape(shape)])
|
||||
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
|
||||
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
|
||||
|
||||
@property
|
||||
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[-1].contiguous
|
||||
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
|
||||
|
||||
@property
|
||||
def key(self) -> Tuple[int, ...]: return tuple(map(View.key, self.views))
|
||||
|
||||
# this is the real size (ish)
|
||||
def size(self): return prod([s for s,st in zip(self.shape, self.views[-1].strides) if st != 0])
|
||||
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
|
||||
|
||||
def unit_stride_axes(self) -> List[int]:
|
||||
ret, acc = [], 1
|
||||
for j,s in list(enumerate(self.shape))[::-1]:
|
||||
for j,s in reversed(list(enumerate(self.shape))):
|
||||
if s == 1: continue
|
||||
var = Variable('idx', 0, s-1)
|
||||
this_dim = self.expr_node(var*acc)
|
||||
@@ -129,7 +179,7 @@ class ShapeTracker:
|
||||
return ret
|
||||
|
||||
def _expr_idx(self, idx, valid):
|
||||
for v in self.views[0:-1][::-1]:
|
||||
for v in reversed(self.views[0:-1]):
|
||||
valid = v.expr_node_mask(idx, valid)
|
||||
idx = v.expr_node(idx)
|
||||
return idx, valid
|
||||
@@ -149,85 +199,72 @@ class ShapeTracker:
|
||||
return self._expr_idx(idx, valid)
|
||||
|
||||
def expr_node(self, idx='idx'):
|
||||
if isinstance(idx, str): idx = Variable(idx, 0, prod(self.shape)-1)
|
||||
if idx.__class__ == str: idx = Variable(idx, 0, prod(self.shape)-1)
|
||||
return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))
|
||||
|
||||
def needs_valid(self) -> bool:
|
||||
return any(v.mask is not None for v in self.views)
|
||||
return any([v.mask is not None for v in self.views])
|
||||
|
||||
# *** under this line are the movement ops ***
|
||||
|
||||
def __unsafe_resize(self, arg: Tuple[Tuple[int, int], ...], mask=None):
|
||||
offset = sum([self.views[-1].strides[i]*x for i,(x,_) in enumerate(arg)])
|
||||
offset = get_unsafe_resize_offset(self.views[-1].strides, arg)
|
||||
if self.views[-1].mask:
|
||||
# move the old mask
|
||||
nmask = tuple((max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.views[-1].mask, arg))
|
||||
nmask = tuple([(max(mx-ax, 0), min(my-ax, ay-ax)) for (mx,my),(ax,ay) in zip(self.views[-1].mask, arg)])
|
||||
# merge the masks if we have two
|
||||
mask = tuple((max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)) if mask is not None else nmask
|
||||
self.views[-1] = View(tuple(y-x for x,y in arg), self.views[-1].strides, self.views[-1].offset+offset, mask)
|
||||
mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
|
||||
self.views[-1] = View(tuple([y-x for x,y in arg]), self.views[-1].strides, self.views[-1].offset+offset, mask)
|
||||
|
||||
def pad(self, arg: Tuple[Tuple[int, int], ...]):
|
||||
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape)
|
||||
if all(b==0 and e==0 for b,e in arg): return self
|
||||
zvarg = tuple((-b,s+e) for s,(b,e) in zip(self.shape, arg))
|
||||
self.__unsafe_resize(zvarg, mask=tuple((b,s+b) for s,(b,_) in zip(self.shape, arg)))
|
||||
if any([b or e for b, e in arg]):
|
||||
zvarg, mask = get_pad_args(self.shape, arg)
|
||||
self.__unsafe_resize(zvarg, mask=mask)
|
||||
return self
|
||||
return self
|
||||
|
||||
def shrink(self, arg: Tuple[Tuple[int, int], ...]):
|
||||
assert all((b>=0 and e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape)
|
||||
self.__unsafe_resize(arg)
|
||||
return self
|
||||
|
||||
def expand(self, new_shape: Tuple[int, ...]):
|
||||
def expand(self, new_shape: Tuple[int, ...]) -> ShapeTracker:
|
||||
assert len(new_shape) == len(self.views[-1].shape)
|
||||
assert all(isinstance(x, int) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.views[-1].strides)), f"can't expand {self.shape} into {new_shape}"
|
||||
# NOTE: can the mask ever be (0,0)?
|
||||
mask = tuple((((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)) if self.views[-1].mask else None
|
||||
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.views[-1].mask, self.shape, new_shape)]) if self.views[-1].mask else None
|
||||
self.views[-1] = View(new_shape, self.views[-1].strides, self.views[-1].offset, mask)
|
||||
return self
|
||||
|
||||
def reshape(self, new_shape: Tuple[int, ...]):
|
||||
if self.shape == new_shape: return self
|
||||
if self.views[-1].shape == new_shape: return self
|
||||
assert all(isinstance(x, int) and x > 0 for x in new_shape), f"shape must be ints and can't contain 0 or negative numbers {new_shape}"
|
||||
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
|
||||
|
||||
# check if this is adding or removing 1s (only)
|
||||
# NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
|
||||
if tuple(x for x in self.shape if x != 1) == tuple(x for x in new_shape if x != 1):
|
||||
old_strides = [y for x,y in zip(self.shape, self.views[-1].strides) if x != 1]
|
||||
new_strides_tuple = tuple(0 if x == 1 else old_strides.pop(0) for x in new_shape)
|
||||
new_mask_tuple = None
|
||||
if self.views[-1].mask:
|
||||
if any(y!=(0,1) for x,y in zip(self.shape, self.views[-1].mask) if x == 1):
|
||||
# mask it all out!
|
||||
new_mask_tuple = tuple((0,0) for _ in new_shape)
|
||||
else:
|
||||
old_mask = [y for x,y in zip(self.shape, self.views[-1].mask) if x != 1]
|
||||
new_mask_tuple = tuple((0,1) if x == 1 else old_mask.pop(0) for x in new_shape)
|
||||
self.views[-1] = View(new_shape, new_strides_tuple, self.views[-1].offset, new_mask_tuple)
|
||||
return self
|
||||
|
||||
view = View(new_shape, strides_for_shape(new_shape))
|
||||
if self.contiguous: self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
|
||||
else:
|
||||
if (merged_view := merge_views(self.views[-1], view)) is not None: self.views[-1] = merged_view
|
||||
else:
|
||||
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {self} -> {new_shape}")
|
||||
self.views.append(view)
|
||||
new_view, extra = _reshape(self.views[-1], new_shape)
|
||||
if extra: self.views.append(new_view)
|
||||
else: self.views[-1] = new_view
|
||||
return self
|
||||
|
||||
def permute(self, axis: Tuple[int, ...]):
|
||||
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
|
||||
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
|
||||
self.views[-1] = View(tuple(self.shape[a] for a in axis), tuple(self.views[-1].strides[a] for a in axis), self.views[-1].offset, tuple(self.views[-1].mask[a] for a in axis) if self.views[-1].mask is not None else None)
|
||||
self.views[-1] = View(tuple([self.views[-1].shape[a] for a in axis]), tuple([self.views[-1].strides[a] for a in axis]), self.views[-1].offset, tuple([self.views[-1].mask[a] for a in axis]) if self.views[-1].mask is not None else None)
|
||||
return self
|
||||
|
||||
# except for the negative case, you can build this from the others. invertible in the negative case
|
||||
def stride(self, mul: Tuple[int, ...]):
|
||||
assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
|
||||
strides = tuple(z*m for z,m in zip(self.views[-1].strides, mul))
|
||||
new_shape = tuple((s+(abs(m)-1))//abs(m) for s,m in zip(self.shape, mul))
|
||||
offset = sum([(s-1)*z for s,z,m in zip(self.shape, self.views[-1].strides, mul) if m < 0])
|
||||
mask = tuple((((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.views[-1].mask, self.shape, mul)) if self.views[-1].mask is not None else None
|
||||
strides = tuple([z*m for z,m in zip(self.views[-1].strides, mul)])
|
||||
new_shape = tuple([(s+(abs(m)-1))//abs(m) for s,m in zip(self.views[-1].shape, mul)])
|
||||
offset = sum([(s-1)*z for s,z,m in zip(self.views[-1].shape, self.views[-1].strides, mul) if m < 0])
|
||||
mask = tuple([(((mx if m > 0 else s-my)+(abs(m)-1))//abs(m), ((my if m > 0 else s-mx)+(abs(m)-1))//abs(m)) for (mx,my),s,m in zip(self.views[-1].mask, self.views[-1].shape, mul)]) if self.views[-1].mask is not None else None
|
||||
self.views[-1] = View(new_shape, strides, self.views[-1].offset + offset, mask)
|
||||
return self
|
||||
|
||||
# *** entry point for external ***
|
||||
|
||||
def movement_op(self, op, arg:Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> ShapeTracker:
|
||||
def movement_op(self, op: MovementOps, arg:Union[Tuple[int, ...], Tuple[Tuple[int, int], ...]]) -> ShapeTracker:
|
||||
assert isinstance(arg, tuple) and (len(arg) == len(self.shape) or op == MovementOps.RESHAPE), f"arg {arg} for {op} doesn't match dim of shape {self.shape}"
|
||||
dispatch[op](self, arg)
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user