mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
move to_indexed_uops to uops (#7011)
* move to_indexed_uops to uops * UOp.range
This commit is contained in:
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
import functools, itertools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, cast, Optional
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import variable_to_uop
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve
|
||||
|
||||
@@ -311,8 +311,9 @@ class UOp(MathTrait):
|
||||
if self in dvars: return dvars[self]
|
||||
return self.replace(src=tuple(x.substitute(dvars) for x in self.src))
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
|
||||
|
||||
@@ -6,18 +6,7 @@ from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, resolve, _get_chain, symbolic_flat
|
||||
|
||||
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
|
||||
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
|
||||
# TODO: dtypes.realint
|
||||
iexpr = variable_to_uop(view.offset)
|
||||
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
|
||||
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*variable_to_uop(st)
|
||||
if m is not None:
|
||||
if resolve(m[0] != 0): vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
|
||||
if resolve(m[1] != sh): vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
|
||||
return iexpr, vexpr
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
@@ -55,17 +44,14 @@ class ShapeTracker:
|
||||
def to_uop(self) -> UOp: return UOp(UOps.VIEW, dtypes.void, (), self)
|
||||
|
||||
def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
|
||||
idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \
|
||||
if _idxs is None else _idxs
|
||||
idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True))
|
||||
idx, valid = self.views[-1].to_indexed_uops(_idxs)
|
||||
for view in reversed(self.views[0:-1]):
|
||||
view = view.minify()
|
||||
acc, idxs = 1, []
|
||||
for _d in reversed(view.shape):
|
||||
d = variable_to_uop(_d)
|
||||
for d in reversed(view.shape):
|
||||
idxs.append((idx//acc)%d)
|
||||
acc *= d
|
||||
idx, valid = _uop_view(view, idxs[::-1], valid)
|
||||
idx, valid = view.to_indexed_uops(idxs[::-1], valid)
|
||||
return idx, valid
|
||||
|
||||
def real_size(self) -> int:
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
import functools, operator, itertools, math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast, Union
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import resolve, UOp
|
||||
from tinygrad.helpers import prod, all_int, argsort
|
||||
from tinygrad.shape.symbolic import NumNode, Variable, sint, sym_infer
|
||||
@@ -82,6 +83,8 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
|
||||
offs -= here * stride
|
||||
return result
|
||||
|
||||
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class View:
|
||||
shape:Tuple[sint, ...]
|
||||
@@ -90,6 +93,16 @@ class View:
|
||||
mask:Optional[Tuple[Tuple[sint, sint], ...]]
|
||||
contiguous:bool
|
||||
|
||||
def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
|
||||
idxs = [UOp.range(dtypes.pyint, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
|
||||
iexpr = variable_to_uop(self.offset)
|
||||
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
|
||||
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
|
||||
if m is not None:
|
||||
if resolve(m[0] != 0): vexpr = vexpr * idx.ge(m[0])
|
||||
if resolve(m[1] != sh): vexpr = vexpr * idx.lt(m[1])
|
||||
return iexpr, vexpr
|
||||
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def size(self) -> int:
|
||||
# NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
|
||||
|
||||
Reference in New Issue
Block a user