diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index b86178f543..3fb0b2eb81 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 from enum import Enum, auto
 import functools
-from typing import Dict, Tuple, Union, List, Optional, Callable, cast
+from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple
 from tinygrad.helpers import prod, DEBUG
 from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
 
@@ -10,7 +10,7 @@ from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
 
 @functools.lru_cache(maxsize=None)
-def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
+def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
   assert len(shape) == len(strides)
   ret = [(shape[0], strides[0])] if len(shape) > 0 else []
   for i in range(1, len(shape)):
@@ -18,7 +18,7 @@ def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tup
       ret[-1] = (ret[-1][0] * shape[i], strides[i])
     else:
       ret.append((shape[i], strides[i]))
-  return ret
+  return tuple(ret)
 
 @functools.lru_cache(maxsize=None)
 def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: return all(s1 == s2 or s == 1 for s,s1,s2 in zip(shape, strides, strides_for_shape(shape)))
 
@@ -27,17 +27,22 @@ def is_contiguous(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> bool: retur
 def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
   return tuple(stride if shp != 1 else 0 for stride, shp in zip(strides, shape))
 
-class View:
-  __slots__ = "shape", "strides", "offset", "mask", "shape_strides", "contiguous"
-  def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0, mask:Optional[Tuple[Tuple[int, int], ...]]=None):
-    self.shape, self.offset = shape, offset
-    self.strides = filter_strides(shape, strides)
-    self.mask = mask
-    self.shape_strides = to_shape_strides(shape, self.strides)
-    self.contiguous: bool = offset == 0 and is_contiguous(shape, self.strides) and mask is None
+class ViewInternal(NamedTuple):
+  shape:Tuple[int, ...]
+  strides:Tuple[int, ...]
+  offset:int
+  mask:Optional[Tuple[Tuple[int, int]]]
+  contiguous:bool
+  shape_strides:Tuple[Tuple[int, int], ...]
 
-  def __repr__(self): return f"View({self.shape}, {self.strides}, {self.offset}, {self.mask})"
-  def key(self): return (self.shape, self.strides, self.offset, self.mask)
+@functools.lru_cache(maxsize=None)
+class View(ViewInternal):
+  def __new__(cls, shape, strides=None, offset=0, mask=None):
+    strides_from_shape = strides_for_shape(shape)
+    strides = strides_from_shape if not strides else filter_strides(shape, strides)
+    contiguous = offset == 0 and is_contiguous(shape, strides) and mask is None
+    return super().__new__(cls, shape, strides, offset, mask, contiguous, to_shape_strides(shape, strides))
+  def __init__(self, shape, strides=None, offset=0, mask=None, contiguous=False, shape_strides=()): super().__init__()
 
   def expr_node_mask(self, idx, valid=None) -> Node:
     expr = [valid] if valid is not None else []
@@ -49,19 +54,10 @@ class View:
       acc *= ns
     return Variable.ands(expr)
 
-  def idxs_to_idx(self, idxs):
-    assert len(idxs) == len(self.shape), "need an idx for all dimensions"
-    acc = 1
-    ret = []
-    for tidx,d in reversed(list(zip(idxs, self.shape))):
-      ret.append(tidx * acc)
-      acc *= d
-    return Variable.sum(ret)
-
   # generate an expression if you have a single idx variable
   def expr_node(self, idx=None) -> Node:
     if idx is None: idx = Variable('idx', 0, prod(self.shape))
-    ret: List[Node] = [Variable.num(self.offset)]
+    ret: List[Node] = [Variable.num(self.offset)] if self.offset else []
     acc = 1
     for d,s in reversed(self.shape_strides):
       ret.append(((idx//acc)%d)*s)
@@ -69,9 +65,19 @@ class View:
     return Variable.sum(ret)
 
   # generate an expression if you have a variable or expression for each index
-  def expr_idxs(self, idxs):
+  def expr_idxs(self, idxs) -> Node:
     assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
     return Variable.sum([Variable.num(self.offset)] + [idx*st for idx,sh,st in zip(idxs, self.shape, self.strides) if sh != 1 and st != 0])
+
+@functools.lru_cache(maxsize=None)
+def idxs_to_idx(shape:Tuple[int, ...], idxs) -> Node:
+  assert len(idxs) == len(shape), "need an idx for all dimensions"
+  acc = 1
+  ret = []
+  for tidx,d in reversed(list(zip(idxs, shape))):
+    ret.append(tidx * acc)
+    acc *= d
+  return Variable.sum(ret)
 
 @functools.lru_cache(maxsize=None)
 def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -90,10 +96,10 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
   mst = ShapeTracker(vm1.shape, [vm2, vm1])
   strides = mst.real_strides()
   if None in strides: return None
-  return View(vm1.shape, cast(Tuple[int, ...], strides), mst.real_offset(), vm1.mask)
+  return View(vm1.shape, strides, mst.real_offset(), vm1.mask)
 
 @functools.lru_cache(maxsize=None)
-def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
+def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
   shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
   # check if this is adding or removing 1s (only)
   # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
@@ -120,7 +126,7 @@ def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
   return new_view, True
 
 @functools.lru_cache(maxsize=None)
-def get_pad_args(shape, arg: Tuple[Tuple[int, int], ...]):
+def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):
   return tuple([(-b,s+e) for s,(b,e) in zip(shape, arg)]), tuple([(b,s+b) for s,(b,_) in zip(shape, arg)])
 
 @functools.lru_cache(maxsize=None)
@@ -141,7 +147,7 @@ class ShapeTracker:
   def shape(self) -> Tuple[int, ...]: return self.views[-1].shape
 
   @property
-  def key(self) -> Tuple[int, ...]: return tuple(map(View.key, self.views))
+  def key(self) -> Tuple[View, ...]: return tuple(self.views)
 
   # this is the real size (ish)
   def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
@@ -187,8 +193,8 @@ class ShapeTracker:
 
   def expr_idxs(self, idxs=None):
     if idxs is None: idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
-    idx = self.views[-1].expr_idxs(idxs)
-    valid = self.views[-1].expr_node_mask(self.views[-1].idxs_to_idx(idxs))
+    idx = self.views[-1].expr_idxs(tuple(idxs))
+    valid = self.views[-1].expr_node_mask(idxs_to_idx(self.views[-1].shape, tuple(idxs)))
     return self._expr_idx(idx, valid)
 
   def expr_node(self, idx='idx'):
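
Note on the View rewrite: decorating the NamedTuple subclass with functools.lru_cache caches instantiation itself, so two Views built from the same arguments are the same interned object, and equality and hashing come structurally from the underlying tuple. That is what lets ShapeTracker.key drop the old View.key() and return the views directly. A minimal, self-contained sketch of the pattern (DemoBase/Demo are hypothetical stand-ins, not the real View):

    import functools
    from typing import NamedTuple, Tuple

    class DemoBase(NamedTuple):
      shape: Tuple[int, ...]
      strides: Tuple[int, ...]

    @functools.lru_cache(maxsize=None)  # caches calls to the class, i.e. construction
    class Demo(DemoBase):
      def __new__(cls, shape, strides=None):
        # derived fields are computed once per distinct argument tuple
        if strides is None:
          acc, out = 1, []
          for d in reversed(shape): out.append(acc); acc *= d
          strides = tuple(reversed(out))
        return super().__new__(cls, shape, strides)

    a, b = Demo((2, 3)), Demo((2, 3))
    assert a is b                         # lru_cache returned the interned instance
    assert a == ((2, 3), (3, 1))          # NamedTuple equality/hash are structural
    assert {a: 1}[((2, 3), (3, 1))] == 1  # usable directly as a cache key

One caveat of the trick: after decoration the name Demo is bound to the lru_cache wrapper rather than the class, so isinstance checks have to go through the base class.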
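idxs_to_idx likewise moves from a View method to a module-level function so it can sit behind functools.lru_cache; that is why the ShapeTracker.expr_idxs call sites now pass tuple(idxs), since lru_cache arguments must be hashable. The function flattens one index per dimension into a single linear row-major index. The same arithmetic on plain ints, as a hypothetical illustration (the real version sums symbolic Variable nodes):

    from typing import Sequence, Tuple

    def idxs_to_idx_ints(shape: Tuple[int, ...], idxs: Sequence[int]) -> int:
      assert len(idxs) == len(shape), "need an idx for all dimensions"
      acc, terms = 1, []
      # walk dimensions from innermost to outermost, scaling each index
      # by the number of elements contained in the dimensions inside it
      for tidx, d in reversed(list(zip(idxs, shape))):
        terms.append(tidx * acc)
        acc *= d
      return sum(terms)

    # for shape (3, 4), index (i, j) flattens to i*4 + j
    assert idxs_to_idx_ints((3, 4), (2, 1)) == 2*4 + 1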
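get_pad_args only gains type annotations here, but since the one-liner reads densely: given a shape and per-dimension (pad_before, pad_behind) amounts, it returns the expanded bounds for the padded view and the interval where the original data remains valid. A worked example reproducing the expression from the patch (the zvarg/mask names are illustrative, not taken from this diff):

    from typing import Tuple

    def get_pad_args(shape: Tuple[int, ...], arg: Tuple[Tuple[int, int], ...]):
      # first tuple: padded bounds (a negative start covers the front padding)
      # second tuple: the [start, end) range still holding the original data
      return tuple([(-b, s+e) for s, (b, e) in zip(shape, arg)]), tuple([(b, s+b) for s, (b, _) in zip(shape, arg)])

    # pad a length-4 axis by 1 in front and 2 behind: new length is 1+4+2 = 7
    zvarg, mask = get_pad_args((4,), ((1, 2),))
    assert zvarg == ((-1, 6),)  # bounds grow from [0, 4) to [-1, 6)
    assert mask == ((1, 5),)    # original elements live at [1, 5) in the padded axis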