cache indexed uops in st [pr] (#8008)

* cache indexed uops in st [pr]

* remove arg from range
This commit is contained in:
George Hotz
2024-12-03 21:27:07 +08:00
committed by GitHub
parent e44183647f
commit 09eac42fd6
5 changed files with 48 additions and 41 deletions

View File

@@ -31,13 +31,14 @@ if __name__ == "__main__":
if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
kernels: List[Kernel] = []
with Timing(f"***** model opts({len(asts):2d}) in "):
for ast in asts:
k = Kernel(ast)
if BEAM:
with Context(DEBUG=max(2, DEBUG.value)): k = beam_search(k, bufs_from_lin(k), BEAM.value)
elif NOOPT: pass
else: k.hand_coded_optimizations()
kernels.append(k)
with Profiling(PROFILE >= 3):
for ast in asts:
k = Kernel(ast)
if BEAM:
with Context(DEBUG=max(2, DEBUG.value)): k = beam_search(k, bufs_from_lin(k), BEAM.value)
elif NOOPT: pass
else: k.hand_coded_optimizations()
kernels.append(k)
with Timing("***** model lower in "): uops = [rewrite_shapetracker_with_index(k.get_optimized_ast(), k.opts) for k in kernels]
with Profiling(PROFILE, fn="/tmp/rewrite.prof"):

View File

@@ -19,7 +19,7 @@ def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UO
def Special(expr, nmax):
  # test helper: an int-typed SPECIAL uop with no sources; (expr, nmax) is stored as its arg
  return UOp(Ops.SPECIAL, dtypes.int, src=(), arg=(expr, nmax))
def Variable(expr, nmin, nmax):
  # test helper: thin alias that forwards all three arguments, in order, to UOp.variable
  var = UOp.variable(expr, nmin, nmax)
  return var
def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=n, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):

View File

@@ -349,7 +349,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@staticmethod
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def r(self, op:Ops, axis:Tuple[int, ...]): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
@@ -1177,7 +1177,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<"
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
renderer = PatternMatcher([
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),

View File

@@ -1,12 +1,44 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
from dataclasses import dataclass
import functools
from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
@functools.lru_cache(None)
def views_to_indexed_uops(views: Tuple[View, ...], _idxs:Optional[Tuple[UOp, ...]]=None) -> Tuple[UOp, UOp]:
  """Push an index through a stack of views and return the resulting (idx, valid) pair.

  The last view is indexed with `_idxs` (passed through unchanged, including None). Every earlier
  view is then applied from the second-to-last back to the first: the running flat `idx` is split
  into one index per dimension of the minified view and handed to that view together with the
  running `valid`.

  Results are memoized with an unbounded lru_cache, so both arguments must be hashable; this is why
  `_idxs` is typed as a tuple rather than a list.
  """
  idx, valid = views[-1].to_indexed_uops(_idxs)
  # views[-2::-1] is every view except the last one, walked backwards (empty for a single view)
  for outer in views[-2::-1]:
    outer = outer.minify()
    # unflatten idx against outer.shape: the last dimension varies fastest
    stride = 1
    per_dim = []
    for dim in reversed(outer.shape):
      per_dim.append((idx//stride)%dim)
      stride *= dim
    # indices were collected innermost-first, the view wants them in shape order
    per_dim.reverse()
    idx, valid = outer.to_indexed_uops(per_dim, valid)
  return idx, valid
@functools.lru_cache(None)
def views_to_real_strides(views: Tuple[View, ...], ignore_valid=False) -> Tuple[Optional[sint], ...]:
  """Return one stride per axis of the last view, recovered from the simplified index expression.

  Per axis the result is:
    * 1 if the axis's RANGE appears as a bare term of the index sum,
    * the constant factor if the term is RANGE*CONST (either operand order),
    * 0 if the axis's RANGE does not appear anywhere in the index expression,
    * None otherwise.
  Unless `ignore_valid` is set, every axis whose RANGE appears in the valid expression is forced to None.

  Memoized with an unbounded lru_cache, so `views` must be hashable.
  """
  # NOTE: if a stride is not always valid, it will be None
  last = views[-1]
  # a single unmasked view already stores its real strides
  if len(views) == 1 and last.mask is None: return last.strides

  # called without explicit indices, so RANGE args are used below as axis numbers of the last view
  idx, valid = (graph_rewrite(u, symbolic_flat) for u in views_to_indexed_uops(views))
  # TODO: always apply these in to_indexed_uops?
  if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
  if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)

  strides: List[Optional[sint]] = [None] * len(last.shape)
  for term in split_uop(idx, Ops.ADD):
    if term.op is Ops.RANGE:
      strides[term.arg] = 1
    elif term.op is Ops.MUL:
      lhs, rhs = term.src[0], term.src[1]
      if lhs.op is Ops.RANGE and rhs.op is Ops.CONST: strides[lhs.arg] = rhs.arg
      elif rhs.op is Ops.RANGE and lhs.op is Ops.CONST: strides[rhs.arg] = lhs.arg

  # an axis that never shows up in the index does not move the address at all
  used_ranges = [node.arg for node in idx.toposort if node.op is Ops.RANGE]
  for axis in range(len(strides)):
    if axis not in used_ranges: strides[axis] = 0

  if not ignore_valid:
    # an axis that takes part in the validity condition has no always-valid stride
    for node in valid.toposort:
      if node.op is Ops.RANGE: strides[node.arg] = None
  return tuple(strides)
@dataclass(frozen=True, order=True)
class ShapeTracker:
views: Tuple[View, ...]
@@ -41,17 +73,8 @@ class ShapeTracker:
def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
idx, valid = self.views[-1].to_indexed_uops(_idxs)
for view in reversed(self.views[0:-1]):
view = view.minify()
acc, idxs = 1, []
for d in reversed(view.shape):
idxs.append((idx//acc)%d)
acc *= d
idx, valid = view.to_indexed_uops(idxs[::-1], valid)
return idx, valid
def to_indexed_uops(self, _idxs:Optional[List[UOp]|Tuple[UOp, ...]]=None) -> Tuple[UOp, UOp]:
return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
def real_size(self) -> int:
if 0 in self.shape: return 0
@@ -69,29 +92,12 @@ class ShapeTracker:
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
# NOTE: if a stride is not always valid, it will be None
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
ret: List[Optional[sint]] = [None] * len(self.shape)
idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
# TODO: always apply these in to_indexed_uops?
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
for c in split_uop(idx, Ops.ADD):
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
used_ranges = [x.arg[0] for x in idx.toposort if x.op is Ops.RANGE]
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
if not ignore_valid:
for masked_axis in [x.arg[0] for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
return tuple(ret)
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def axis_is_masked(self, axis:int) -> bool:
_, valid = self.to_indexed_uops()
return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
def simplify(self) -> ShapeTracker:
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:

View File

@@ -95,7 +95,7 @@ class View:
for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
def __lt__(self, o:View): return self.t < o.t
def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
def to_indexed_uops(self:View, _idxs:Optional[List[UOp]|Tuple[UOp, ...]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
iexpr = sint_to_uop(self.offset)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):