mirror of https://github.com/tinygrad/tinygrad.git
View as namedtuple, cached methods (#1075)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
@@ -2,7 +2,7 @@
 from __future__ import annotations
 from enum import Enum, auto
 import functools
-from typing import Dict, Tuple, Union, List, Optional, Callable, cast
+from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple
 from tinygrad.helpers import prod, DEBUG
 from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
 
@@ -10,7 +10,7 @@ from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
 
 @functools.lru_cache(maxsize=None)
-def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
+def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
   assert len(shape) == len(strides)
   ret = [(shape[0], strides[0])] if len(shape) > 0 else []
   for i in range(1, len(shape)):
@@ -18,7 +18,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
       ret[-1] = (ret[-1][0] * shape[i], strides[i])
     else:
       ret.append((shape[i], strides[i]))
-  return ret
+  return tuple(ret)
 
 @functools.lru_cache(maxsize=None)
 def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))
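
Note on the hunk above: returning a tuple keeps the merged (dimension, stride) pairs hashable and immutable, which the cached View further down depends on. Below is a minimal standalone sketch of the merging idea (the merge condition is paraphrased, not copied from the diff), with two worked cases:

from typing import List, Tuple

def to_shape_strides_sketch(shape: Tuple[int, ...], strides: Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
  # fold a dimension into the previous one whenever the two are laid out contiguously
  ret: List[Tuple[int, int]] = [(shape[0], strides[0])] if shape else []
  for sh, st in zip(shape[1:], strides[1:]):
    if ret[-1][0] == 1 or (st != 0 and ret[-1][1] == sh * st) or (st == 0 and ret[-1][1] == 0):
      ret[-1] = (ret[-1][0] * sh, st)
    else:
      ret.append((sh, st))
  return tuple(ret)  # tuple, so the result is hashable and safe to cache

assert to_shape_strides_sketch((2, 3, 4), (12, 4, 1)) == ((24, 1),)        # fully contiguous -> one merged dim
assert to_shape_strides_sketch((2, 3, 4), (12, 4, 2)) == ((6, 4), (4, 2))  # strided last dim stays separate
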
@@ -27,17 +27,22 @@ def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: retur
 def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
   return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
 
-class View:
-  __slots__ = "shape", "strides", "offset", "mask", "shape_strides", "contiguous"
-  def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0, mask:Optional[Tuple[Tuple[int, int], ...]]=None):
-    self.shape, self.offset = shape, offset
-    self.strides = filter_strides(shape, strides)
-    self.mask = mask
-    self.shape_strides = to_shape_strides(shape, self.strides)
-    self.contiguous: bool = offset == 0 and is_contiguous(shape, self.strides) and mask is None
-
+class ViewInternal(NamedTuple):
+  shape:Tuple[int, ...]
+  strides:Tuple[int, ...]
+  offset:int
+  mask:Optional[Tuple[Tuple[int, int]]]
+  contiguous:bool
+  shape_strides:Tuple[Tuple[int, int], ...]
   def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset}, {self.mask})"
-  def key(self): return (self.shape, self.strides, self.offset, self.mask)
+
+@functools.lru_cache(maxsize=None)
+class View(ViewInternal):
+  def __new__(cls, shape, strides=None, offset=0, mask=None):
+    strides_from_shape = strides_for_shape(shape)
+    strides = strides_from_shape if not strides else filter_strides(shape, strides)
+    contiguous = offset == 0 and is_contiguous(shape, strides) and mask is None
+    return super().__new__(cls, shape, strides, offset, mask, contiguous, to_shape_strides(shape, strides))
+  def __init__(self, shape, strides=None, offset=0, mask=None, contiguous=False, shape_strides=()): super().__init__()
 
   def expr_node_mask(self, idx, valid=None) -> Node:
     expr = [valid] if valid is not None else []
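
The hunk above replaces the mutable, __slots__-based View with a NamedTuple base plus functools.lru_cache applied to the class itself, so constructing a View with equal arguments hands back the same immutable, value-hashable instance. A minimal sketch of that pattern using a hypothetical Point class (not tinygrad code):

import functools
from typing import NamedTuple, Tuple

class PointInternal(NamedTuple):
  coords: Tuple[int, ...]
  dims: int  # derived field, filled in by __new__ below

@functools.lru_cache(maxsize=None)
class Point(PointInternal):
  def __new__(cls, coords: Tuple[int, ...]):
    # derived values are computed once and stored as ordinary tuple fields
    return super().__new__(cls, coords, len(coords))

a, b = Point((1, 2, 3)), Point((1, 2, 3))
assert a is b                                # lru_cache keys the call on its (hashable) arguments
assert a.dims == 3 and a == ((1, 2, 3), 3)   # instances still behave like plain tuples

Because instances compare and hash by value, a tuple of them works directly as a cache key, which is what lets ShapeTracker.key later in this diff return the views themselves.
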
@@ -49,19 +54,10 @@ class View:
       acc *= ns
     return Variable.ands(expr)
 
-  def idxs_to_idx(self, idxs):
-    assert len(idxs) == len(self.shape), "need an idx for all dimensions"
-    acc = 1
-    ret = []
-    for tidx,d in reversed(list(zip(idxs, self.shape))):
-      ret.append(tidx * acc)
-      acc *= d
-    return Variable.sum(ret)
-
   # generate an expression if you have a single idx variable
   def expr_node(self, idx=None) -> Node:
     if idx is None: idx = Variable('idx', 0, prod(self.shape))
-    ret: List[Node] = [Variable.num(self.offset)]
+    ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
     acc = 1
     for d,s in reversed(self.shape_strides):
       ret.append(((idx//acc)%d)*s)
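
Two things happen in the hunk above: idxs_to_idx leaves the class (it reappears below as a cached module-level function), and expr_node stops emitting a literal 0 term when the view has no offset. The index arithmetic expr_node builds can be checked with a plain-integer analogue; the sketch below uses ints instead of symbolic Nodes, so it only illustrates the math:

def flat_offset(idx: int, shape_strides, offset: int = 0) -> int:
  # same loop as expr_node: peel the flat index apart per (dim, stride) pair and scale by the stride
  ret = [offset] if offset else []  # the "if self.offset" tweak above: skip a useless 0 term
  acc = 1
  for d, s in reversed(shape_strides):
    ret.append(((idx // acc) % d) * s)
    acc *= d
  return sum(ret)

# contiguous 3x4 view: flat index 6 is element (1, 2), stored at 1*4 + 2 = 6
assert flat_offset(6, ((3, 4), (4, 1))) == 6
# 3x4 view with (dim, stride) pairs (3, 8), (4, 2): element (1, 2) lives at 1*8 + 2*2 = 12
assert flat_offset(6, ((3, 8), (4, 2))) == 12
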
@@ -69,9 +65,19 @@ class View:
     return Variable.sum(ret)
 
   # generate an expression if you have a variable or expression for each index
-  def expr_idxs(self, idxs):
+  def expr_idxs(self, idxs) -> Node:
     assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
     return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
 
+@functools.lru_cache(maxsize=None)
+def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
+  assert len(idxs) == len(shape), "need an idx for all dimensions"
+  acc = 1
+  ret = []
+  for tidx,d in reversed(list(zip(idxs, shape))):
+    ret.append(tidx * acc)
+    acc *= d
+  return Variable.sum(ret)
+
 @functools.lru_cache(maxsize=None)
 def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
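
idxs_to_idx is now a free function memoized with functools.lru_cache, so its arguments must be hashable; that is why the ShapeTracker.expr_idxs hunk further down wraps idxs in tuple(...) before calling it. A small sketch of the constraint, using a hypothetical cached_len helper rather than tinygrad code:

import functools

@functools.lru_cache(maxsize=None)
def cached_len(xs):
  return len(xs)

cached_len((1, 2, 3))    # fine: tuples hash, so the call can be memoized
try:
  cached_len([1, 2, 3])  # lists do not hash; lru_cache raises before the body runs
except TypeError as e:
  print(e)               # unhashable type: 'list'
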
@@ -90,10 +96,10 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
   mst = ShapeTracker(vm1.shape, [vm2, vm1])
   strides = mst.real_strides()
   if None in strides: return None
-  return View(vm1.shape, cast(Tuple[int, ...], strides), mst.real_offset(), vm1.mask)
+  return View(vm1.shape, strides, mst.real_offset(), vm1.mask)
 
 @functools.lru_cache(maxsize=None)
-def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
+def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
   shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
   # check if this is adding or removing 1s (only)
   # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
@@ -120,7 +126,7 @@ def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
   return new_view, True
 
 @functools.lru_cache(maxsize=None)
-def get_pad_args(shape, arg: Tuple[Tuple[int, int], ...]):
+def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):
   return tuple([(-b,s+e) for s,(b,e) in zip(shape, arg)]), tuple([(b,s+b) for s,(b,_) in zip(shape, arg)])
 
 @functools.lru_cache(maxsize=None)
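
The get_pad_args change above only tightens the type hints, but the one-liner is easy to misread, so here is the same computation restated as a standalone sketch with a worked example: padding a (3, 4) view by one row on each side and two trailing columns.

def get_pad_args_sketch(shape, arg):
  # per dimension: the widened index range, and the sub-range where the real data lives
  zvarg = tuple((-b, s + e) for s, (b, e) in zip(shape, arg))
  mask = tuple((b, s + b) for s, (b, _) in zip(shape, arg))
  return zvarg, mask

zvarg, mask = get_pad_args_sketch((3, 4), ((1, 1), (0, 2)))
assert zvarg == ((-1, 4), (0, 6))  # padded view spans row indices -1..3 and column indices 0..5
assert mask == ((1, 4), (0, 4))    # original data sits at rows 1..3, columns 0..3 of the padded view
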
@@ -141,7 +147,7 @@ class ShapeTracker:
   def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
 
   @property
-  def key(self) -> Tuple[int, ...]: return tuple(map(View.key, self.views))
+  def key(self) -> Tuple[View, ...]: return tuple(self.views)
 
   # this is the real size (ish)
   def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
@@ -187,8 +193,8 @@ class ShapeTracker:
 
   def expr_idxs(self, idxs=None):
     if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
-    idx = self.views[-1].expr_idxs(idxs)
-    valid = self.views[-1].expr_node_mask(self.views[-1].idxs_to_idx(idxs))
+    idx = self.views[-1].expr_idxs(tuple(idxs))
+    valid = self.views[-1].expr_node_mask(idxs_to_idx(self.views[-1].shape, tuple(idxs)))
     return self._expr_idx(idx, valid)
 
   def expr_node(self, idx='idx'):