mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
refactor: expr_view on View (#5315)
This commit is contained in:
@@ -3,18 +3,9 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, create_lt_node, create_ge_node, sint
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
|
||||
def _expr_view(view:View, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
|
||||
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
||||
iexpr: List[Node] = [NumNode(view.offset) if isinstance(view.offset, int) else view.offset]
|
||||
vexpr: List[Node] = [valid] if valid is not None else []
|
||||
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
|
||||
if sh != 1 and st != 0: iexpr.append(idx*st)
|
||||
if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
|
||||
return Node.sum(iexpr), Node.ands(vexpr)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
views: Tuple[View, ...]
|
||||
@@ -86,7 +77,7 @@ class ShapeTracker:
|
||||
|
||||
def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
|
||||
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
|
||||
idx, valid = _expr_view(self.views[-1], idxs)
|
||||
idx, valid = self.views[-1].expr(idxs)
|
||||
for view in reversed(self.views[0:-1]):
|
||||
if valid.max == 0: return NumNode(-1), valid
|
||||
view = view.minify()
|
||||
@@ -94,7 +85,7 @@ class ShapeTracker:
|
||||
for d in reversed(view.shape):
|
||||
idxs.append((idx//acc)%d)
|
||||
acc *= d
|
||||
idx, valid = _expr_view(view, idxs[::-1], valid)
|
||||
idx, valid = view.expr(idxs[::-1], valid)
|
||||
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
|
||||
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
|
||||
return idx, valid
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools, operator, itertools, math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast
|
||||
from tinygrad.helpers import prod, all_int, argsort
|
||||
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
|
||||
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer, create_lt_node, create_ge_node
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
||||
@@ -309,3 +309,12 @@ class View:
|
||||
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
|
||||
|
||||
return None
|
||||
|
||||
def expr(self, idxs:List[Node], valid:Optional[Node]=None) -> Tuple[Node, Node]:
|
||||
assert len(idxs) == len(self.shape), f"need an idx for all dimensions {idxs} vs {self.shape}"
|
||||
iexpr: List[Node] = [NumNode(self.offset) if isinstance(self.offset, int) else self.offset]
|
||||
vexpr: List[Node] = [valid] if valid is not None else []
|
||||
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
|
||||
if sh != 1 and st != 0: iexpr.append(idx*st)
|
||||
if m is not None: vexpr += [create_ge_node(idx, m[0]), create_lt_node(idx, m[1])] # idx >= m[0], idx < m[1]
|
||||
return Node.sum(iexpr), Node.ands(vexpr)
|
||||
|
||||
Reference in New Issue
Block a user