refactor: expr_view on View (#5315)

This commit is contained in:
Roelof van Dijk
2024-07-08 20:47:34 +02:00
committed by GitHub
parent 2349d837fb
commit 053c706961
2 changed files with 13 additions and 13 deletions

View File

@@ -3,18 +3,9 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
from tinygrad.shape.view import View, strides_for_shape
def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
vexpr: List[Node] = [valid] if valid is not None else []
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
if sh != 1 and st != 0: iexpr.append(idx*st)
if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
return Node.sum(iexpr), Node.ands(vexpr)
@dataclass(frozen=True)
class ShapeTracker:
views: Tuple[View, ...]
@@ -86,7 +77,7 @@ class ShapeTracker:
def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
idx, valid = _expr_view(self.views[-1], idxs)
idx, valid = self.views[-1].expr(idxs)
for view in reversed(self.views[0:-1]):
if valid.max == 0: return NumNode(-1), valid
view = view.minify()
@@ -94,7 +85,7 @@ class ShapeTracker:
for d in reversed(view.shape):
idxs.append((idx//acc)%d)
acc *= d
idx, valid = _expr_view(view, idxs[::-1], valid)
idx, valid = view.expr(idxs[::-1], valid)
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
return idx, valid

View File

@@ -3,7 +3,7 @@ import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast
from tinygrad.helpers import prod, all_int, argsort
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer, create_lt_node, create_ge_node
@functools.lru_cache(maxsize=None)
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
@@ -309,3 +309,12 @@ class View:
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
return None
def expr(self, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
iexpr: List[Node] = [NumNode(self.offset) if isinstance(self.offset, int) else self.offset]
vexpr: List[Node] = [valid] if valid is not None else []
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
if sh != 1 and st != 0: iexpr.append(idx*st)
if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
return Node.sum(iexpr), Node.ands(vexpr)