From b737ee5bac40fc4711fda7877c115cc28015167e Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 12 Oct 2024 18:20:57 +0800
Subject: [PATCH] move to_indexed_uops to uops (#7011)

* move to_indexed_uops to uops

* UOp.range
---
 tinygrad/codegen/lowerer.py    |  3 ++-
 tinygrad/ops.py                |  5 +++--
 tinygrad/shape/shapetracker.py | 22 ++++------------------
 tinygrad/shape/view.py         | 13 +++++++++++++
 4 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 3cb804769b..da6915dc5a 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -3,7 +3,8 @@ from __future__ import annotations
 import functools, itertools, operator
 from dataclasses import dataclass
 from typing import List, Tuple, cast, Optional
-from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import variable_to_uop
 from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes
 from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 837b759c0d..d780b30419 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -311,8 +311,9 @@ class UOp(MathTrait):
     if self in dvars: return dvars[self]
     return self.replace(src=tuple(x.substitute(dvars) for x in self.src))
   @staticmethod
-  def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
-    return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
+  def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
+    return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
+                                             UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
   def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
   @functools.cached_property
   def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 7bad9f8890..3284858f36 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -6,18 +6,7 @@ from tinygrad.helpers import merge_dicts, getenv
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.view import View, strides_for_shape
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, resolve, _get_chain, symbolic_flat
-
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
-def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
-  # TODO: dtypes.realint
-  iexpr = variable_to_uop(view.offset)
-  for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
-    if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*variable_to_uop(st)
-    if m is not None:
-      if resolve(m[0] != 0): vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
-      if resolve(m[1] != sh): vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
-  return iexpr, vexpr
+from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat
 
 @dataclass(frozen=True)
 class ShapeTracker:
@@ -55,17 +44,14 @@ class ShapeTracker:
   def to_uop(self) -> UOp: return UOp(UOps.VIEW, dtypes.void, (), self)
 
   def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
-    idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \
-      if _idxs is None else _idxs
-    idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True))
+    idx, valid = self.views[-1].to_indexed_uops(_idxs)
     for view in reversed(self.views[0:-1]):
       view = view.minify()
       acc, idxs = 1, []
-      for _d in reversed(view.shape):
-        d = variable_to_uop(_d)
+      for d in reversed(view.shape):
         idxs.append((idx//acc)%d)
         acc *= d
-      idx, valid = _uop_view(view, idxs[::-1], valid)
+      idx, valid = view.to_indexed_uops(idxs[::-1], valid)
     return idx, valid
 
   def real_size(self) -> int:
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index 719703859b..99fbf7fff0 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast, Union
+from tinygrad.dtype import dtypes
 from tinygrad.ops import resolve, UOp
 from tinygrad.helpers import prod, all_int, argsort
 from tinygrad.shape.symbolic import NumNode, Variable, sint, sym_infer
@@ -82,6 +83,8 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
     offs -= here * stride
   return result
 
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
+
 @dataclass(frozen=True)
 class View:
   shape:Tuple[sint, ...]
@@ -90,6 +93,16 @@ class View:
   strides:Tuple[sint, ...]
   offset:sint
   mask:Optional[Tuple[Tuple[sint, sint], ...]]
   contiguous:bool
 
+  def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
+    idxs = [UOp.range(dtypes.pyint, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
+    iexpr = variable_to_uop(self.offset)
+    for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
+      if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
+      if m is not None:
+        if resolve(m[0] != 0): vexpr = vexpr * idx.ge(m[0])
+        if resolve(m[1] != sh): vexpr = vexpr * idx.lt(m[1])
+    return iexpr, vexpr
+
   @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
   def size(self) -> int:
     # NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
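
Note (reviewer sketch, not part of the patch): the relocated View.to_indexed_uops builds its indexing expression symbolically out of UOps, but the arithmetic it encodes is easy to state in plain Python. The helper below is illustrative only (the name view_index and the concrete numbers are not from the patch); it mirrors the patched logic: the flat buffer index is offset plus the sum of idx*stride over dimensions whose shape is not 1 and stride is not 0, and the access is valid only while every masked index stays inside its half-open [lo, hi) range. In the real method, + and * on UOps build the same expression as a graph, with UOp.range (which after this patch also accepts UOp bounds and stores a plain int arg) supplying the loop indices.

from typing import List, Optional, Tuple

def view_index(shape:List[int], strides:List[int], offset:int,
               mask:Optional[List[Tuple[int,int]]], idxs:List[int]) -> Tuple[int, bool]:
  # plain-int model of View.to_indexed_uops: returns (flat buffer index, access is valid)
  iexpr, vexpr = offset, True
  for idx,sh,st,m in zip(idxs, shape, strides, mask if mask is not None else [None]*len(shape)):
    if sh != 1 and st != 0: iexpr += idx*st          # size-1 and stride-0 dims never move the index
    if m is not None:
      if m[0] != 0: vexpr = vexpr and idx >= m[0]    # mask lower bound
      if m[1] != sh: vexpr = vexpr and idx < m[1]    # mask upper bound (half-open)
  return iexpr, vexpr

# e.g. a 2x3 row-major view whose mask keeps only the first two columns:
assert view_index([2,3], [3,1], 0, [(0,2),(0,2)], [1,1]) == (4, True)
assert view_index([2,3], [3,1], 0, [(0,2),(0,2)], [1,2]) == (5, False)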