move to_indexed_uops to uops (#7011)

* move to_indexed_uops to uops

* UOp.range
George Hotz
2024-10-12 18:20:57 +08:00
committed by GitHub
parent 5ae2de9845
commit b737ee5bac
4 changed files with 22 additions and 21 deletions

tinygrad/codegen/lowerer.py (View File)

@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools, itertools, operator
 from dataclasses import dataclass
 from typing import List, Tuple, cast, Optional
-from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import variable_to_uop
 from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes
 from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve

tinygrad/ops.py (View File)

@@ -311,8 +311,9 @@ class UOp(MathTrait):
     if self in dvars: return dvars[self]
     return self.replace(src=tuple(x.substitute(dvars) for x in self.src))
   @staticmethod
-  def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
-    return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
+  def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
+    return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
+                                             UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
   def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
   @functools.cached_property
   def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
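
With this change UOp.range accepts either a Python constant or an existing UOp as a bound: ints are wrapped via UOp.const, UOps pass through untouched. A minimal usage sketch against the tinygrad internals at this commit (the variable names are illustrative):

# sketch: both int and UOp bounds are valid after this commit
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp

r0 = UOp.range(dtypes.pyint, 0, 16, 0)     # int bounds, wrapped in UOp.const internally
end = UOp.const(dtypes.pyint, 16)
r1 = UOp.range(dtypes.pyint, 0, end, 1)    # a bound that is already a UOp is used as-is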

tinygrad/shape/shapetracker.py (View File)

@@ -6,18 +6,7 @@ from tinygrad.helpers import merge_dicts, getenv
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, resolve, _get_chain, symbolic_flat
+from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat

-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
-
-def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
-  # TODO: dtypes.realint
-  iexpr = variable_to_uop(view.offset)
-  for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
-    if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*variable_to_uop(st)
-    if m is not None:
-      if resolve(m[0] != 0): vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
-      if resolve(m[1] != sh): vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
-  return iexpr, vexpr
 @dataclass(frozen=True)
 class ShapeTracker:
@@ -55,17 +44,14 @@ class ShapeTracker:
   def to_uop(self) -> UOp: return UOp(UOps.VIEW, dtypes.void, (), self)

   def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
-    idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \
-      if _idxs is None else _idxs
-    idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True))
+    idx, valid = self.views[-1].to_indexed_uops(_idxs)
     for view in reversed(self.views[0:-1]):
       view = view.minify()
       acc, idxs = 1, []
-      for _d in reversed(view.shape):
-        d = variable_to_uop(_d)
+      for d in reversed(view.shape):
         idxs.append((idx//acc)%d)
         acc *= d
-      idx, valid = _uop_view(view, idxs[::-1], valid)
+      idx, valid = view.to_indexed_uops(idxs[::-1], valid)
     return idx, valid

   def real_size(self) -> int:
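
ShapeTracker.to_indexed_uops now delegates the per-view index and mask arithmetic to View.to_indexed_uops in place of the removed module-level _uop_view helper. A sketch of the resulting call, assuming ShapeTracker.from_shape as it exists elsewhere in tinygrad:

# sketch: build index/valid UOps for a simple 4x4 tracker
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4, 4))  # from_shape is an assumption, not part of this diff
idx, valid = st.to_indexed_uops()     # idx: flat-index UOp, valid: boolean mask UOp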

tinygrad/shape/view.py (View File)

@@ -2,6 +2,7 @@ from __future__ import annotations
 import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast, Union
+from tinygrad.dtype import dtypes
 from tinygrad.ops import resolve, UOp
 from tinygrad.helpers import prod, all_int, argsort
 from tinygrad.shape.symbolic import NumNode, Variable, sint, sym_infer
@@ -82,6 +83,8 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
     offs -= here * stride
   return result

+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
+
 @dataclass(frozen=True)
 class View:
   shape:Tuple[sint, ...]
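
The relocated variable_to_uop helper normalizes int sizes to pyint constants and passes symbolic values through unchanged. For instance (a sketch against this commit):

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.shape.view import variable_to_uop

u = variable_to_uop(5)                             # -> UOp.const(dtypes.pyint, 5)
v = variable_to_uop(UOp.const(dtypes.pyint, 7))    # already a UOp: returned unchanged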
@@ -90,6 +93,16 @@ class View:
   mask:Optional[Tuple[Tuple[sint, sint], ...]]
   contiguous:bool

+  def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
+    idxs = [UOp.range(dtypes.pyint, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
+    iexpr = variable_to_uop(self.offset)
+    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
+      if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
+      if m is not None:
+        if resolve(m[0] != 0): vexpr = vexpr * idx.ge(m[0])
+        if resolve(m[1] != sh): vexpr = vexpr * idx.lt(m[1])
+    return iexpr, vexpr
+
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
   def size(self) -> int:
     # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
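
In the new View.to_indexed_uops, iexpr accumulates offset + idx*stride over dimensions with non-unit shape and non-zero stride, while vexpr multiplies in the mask bound checks. A sketch, assuming View.create as the usual tinygrad constructor:

# sketch: a 4x4 view whose first axis is only valid on rows 1..2
from tinygrad.shape.view import View

v = View.create(shape=(4, 4), mask=((1, 3), (0, 4)))
iexpr, vexpr = v.to_indexed_uops()   # iexpr: flat-index UOp; vexpr picks up idx0.ge(1) * idx0.lt(3)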