View as namedtuple, cached methods (#1075)

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
Roelof van Dijk
2023-07-09 23:26:02 +02:00
committed by GitHub
parent 1eb0e0cb3f
commit e27f098946

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from enum import Enum, auto
import functools
from typing import Dict, Tuple, Union, List, Optional, Callable, cast
from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
@@ -10,7 +10,7 @@ from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
# Closed set of view-manipulation ops a ShapeTracker can express.
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])] if len(shape) > 0 else []
for i in range(1, len(shape)):
@@ -18,7 +18,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
ret[-1] = (ret[-1][0] * shape[i], strides[i])
else:
ret.append((shape[i], strides[i]))
return ret
return tuple(ret)
@functools.lru_cache(maxsize=None)
def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool:
  """True iff the strides describe a plain row-major layout for this shape.

  A dimension of extent 1 is layout-irrelevant, so its stride may be anything.
  """
  for dim, actual, canonical in zip(shape, strides, strides_for_shape(shape)):
    if dim != 1 and actual != canonical:
      return False
  return True
@@ -27,17 +27,22 @@ def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: retur
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
  """Zero the stride of every size-1 dimension.

  A dim of extent 1 is never actually stepped, so canonicalizing its stride
  to 0 makes equivalent views compare (and hash) equal.
  """
  out = []
  for st, dim in zip(strides, shape):
    out.append(0 if dim == 1 else st)
  return tuple(out)
class View:
__slots__ = "shape", "strides", "offset", "mask", "shape_strides", "contiguous"
def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0, mask:Optional[Tuple[Tuple[int, int], ...]]=None):
  # Normalize inputs and cache derived layout info at construction time.
  self.shape, self.offset = shape, offset
  # canonicalize: size-1 dims get stride 0 so equivalent views compare equal
  self.strides = filter_strides(shape, strides)
  # per-dimension (begin, end) valid ranges; None means the whole view is valid
  self.mask = mask
  # adjacent contiguously-strided dims pre-merged into (shape, stride) pairs
  self.shape_strides = to_shape_strides(shape, self.strides)
  # contiguous <=> dense row-major over the whole buffer: no offset, no mask
  self.contiguous: bool = offset == 0 and is_contiguous(shape, self.strides) and mask is None
class ViewInternal(NamedTuple):
  # Immutable field storage for View. Being a NamedTuple makes instances
  # hashable/comparable by value, which is what lets View itself be lru_cached.
  shape:Tuple[int, ...]
  strides:Tuple[int, ...]
  offset:int
  mask:Optional[Tuple[Tuple[int, int], ...]]  # per-dim (begin, end) valid ranges, or None
  contiguous:bool  # True iff dense row-major: no offset, no mask, canonical strides
  shape_strides:Tuple[Tuple[int, int], ...]  # adjacent contiguous dims pre-merged
# Readable form showing only the four defining fields (derived ones omitted).
def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset}, {self.mask})"
# The defining fields double as the cache/identity key for this view.
def key(self): return (self.shape, self.strides, self.offset, self.mask)
@functools.lru_cache(maxsize=None)
class View(ViewInternal):
def __new__(cls, shape, strides=None, offset=0, mask=None):
  # No strides (or empty) means canonical row-major; otherwise canonicalize by
  # zeroing strides of size-1 dims so equivalent views hash/compare equal.
  strides_from_shape = strides_for_shape(shape)
  strides = strides_from_shape if not strides else filter_strides(shape, strides)
  # contiguous <=> dense row-major over the whole buffer: no offset, no mask
  contiguous = offset == 0 and is_contiguous(shape, strides) and mask is None
  # derived fields (contiguous, shape_strides) are computed once and stored in the tuple
  return super().__new__(cls, shape, strides, offset, mask, contiguous, to_shape_strides(shape, strides))
# All fields are set by NamedTuple.__new__ above; this only mirrors the
# signature so the class-level lru_cache can dispatch the same arguments.
def __init__(self, shape, strides=None, offset=0, mask=None, contiguous=False, shape_strides=()): super().__init__()
def expr_node_mask(self, idx, valid=None) -> Node:
expr = [valid] if valid is not None else []
@@ -49,19 +54,10 @@ class View:
acc *= ns
return Variable.ands(expr)
def idxs_to_idx(self, idxs):
  """Fold one index Node per dimension into a single flat row-major index."""
  assert len(idxs) == len(self.shape), "need an idx for all dimensions"
  terms, scale = [], 1
  # innermost dimension varies fastest, so walk dims from the right
  for dim_idx, dim_size in zip(reversed(idxs), reversed(self.shape)):
    terms.append(dim_idx * scale)
    scale *= dim_size
  return Variable.sum(terms)
# generate an expression if you have a single idx variable
def expr_node(self, idx=None) -> Node:
if idx is None: idx = Variable('idx', 0, prod(self.shape))
ret: List[Node] = [Variable.num(self.offset)]
ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
acc = 1
for d,s in reversed(self.shape_strides):
ret.append(((idx//acc)%d)*s)
@@ -69,9 +65,19 @@ class View:
return Variable.sum(ret)
# generate an expression if you have a variable or expression for each index
def expr_idxs(self, idxs) -> Node:
  """Return offset + sum(idx*stride) over the real dimensions.

  Dims of extent 1 or stride 0 contribute nothing and are skipped.
  NOTE: the diff rendering had interleaved the old and new signatures as two
  consecutive `def` lines; only the committed annotated signature is kept.
  """
  assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
  return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
  """Fold one index Node per dimension into a single flat row-major index.

  Cached, so both arguments must be hashable (tuples).
  """
  assert len(idxs) == len(shape), "need an idx for all dimensions"
  terms, scale = [], 1
  # innermost dimension varies fastest, so walk dims from the right
  for dim_idx, dim_size in zip(reversed(idxs), reversed(shape)):
    terms.append(dim_idx * scale)
    scale *= dim_size
  return Variable.sum(terms)
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -90,10 +96,10 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
mst = ShapeTracker(vm1.shape, [vm2, vm1])
strides = mst.real_strides()
if None in strides: return None
return View(vm1.shape, cast(Tuple[int, ...], strides), mst.real_offset(), vm1.mask)
return View(vm1.shape, strides, mst.real_offset(), vm1.mask)
@functools.lru_cache(maxsize=None)
def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
# check if this is adding or removing 1s (only)
# NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
@@ -120,7 +126,7 @@ def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
return new_view, True
@functools.lru_cache(maxsize=None)
def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):
  """Translate pad amounts into (expand-shrink ranges, new mask ranges).

  arg gives (before, after) padding per dimension. Returns:
    * per-dim (-before, size+after): the widened index range covering the pad
    * per-dim (before, size+before): where the original data sits (the mask)
  NOTE: the diff rendering had interleaved the old and new signatures as two
  consecutive `def` lines; only the committed annotated signature is kept.
  """
  return tuple([(-b,s+e) for s,(b,e) in zip(shape, arg)]), tuple([(b,s+b) for s,(b,_) in zip(shape, arg)])
@functools.lru_cache(maxsize=None)
@@ -141,7 +147,7 @@ class ShapeTracker:
# shape callers observe: that of the last (outermost) view in the chain
def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
@property
def key(self) -> Tuple[View, ...]:
  # Views are value-hashable NamedTuples now, so the view chain itself is the
  # cache key. (The diff rendering had interleaved the old signature — which
  # mapped View.key over the views — with this committed one; only the
  # committed version is kept.)
  return tuple(self.views)
# this is the real size (ish)
# elements actually addressed: product of last-view dims with nonzero stride
# (stride-0 dims are broadcasts and consume no extra storage)
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
@@ -187,8 +193,8 @@ class ShapeTracker:
def expr_idxs(self, idxs=None):
  """Build (index expression, validity mask) from one index per dimension.

  idxs defaults to fresh Variables spanning each dim. They are passed as a
  tuple because the cached helpers (View.expr_idxs, idxs_to_idx) need
  hashable arguments. NOTE: the diff rendering had interleaved the old
  statement pair (method-based idxs_to_idx) with the committed one; only the
  committed statements are kept.
  """
  if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
  idx = self.views[-1].expr_idxs(tuple(idxs))
  valid = self.views[-1].expr_node_mask(idxs_to_idx(self.views[-1].shape, tuple(idxs)))
  return self._expr_idx(idx, valid)
def expr_node(self, idx='idx'):