mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cache indexed uops in st [pr] (#8008)
* cache indexed uops in st [pr] * remove arg from range
This commit is contained in:
15
test/external/external_benchmark_schedule.py
vendored
15
test/external/external_benchmark_schedule.py
vendored
@@ -31,13 +31,14 @@ if __name__ == "__main__":
|
||||
if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
|
||||
kernels: List[Kernel] = []
|
||||
with Timing(f"***** model opts({len(asts):2d}) in "):
|
||||
for ast in asts:
|
||||
k = Kernel(ast)
|
||||
if BEAM:
|
||||
with Context(DEBUG=max(2, DEBUG.value)): k = beam_search(k, bufs_from_lin(k), BEAM.value)
|
||||
elif NOOPT: pass
|
||||
else: k.hand_coded_optimizations()
|
||||
kernels.append(k)
|
||||
with Profiling(PROFILE >= 3):
|
||||
for ast in asts:
|
||||
k = Kernel(ast)
|
||||
if BEAM:
|
||||
with Context(DEBUG=max(2, DEBUG.value)): k = beam_search(k, bufs_from_lin(k), BEAM.value)
|
||||
elif NOOPT: pass
|
||||
else: k.hand_coded_optimizations()
|
||||
kernels.append(k)
|
||||
|
||||
with Timing("***** model lower in "): uops = [rewrite_shapetracker_with_index(k.get_optimized_ast(), k.opts) for k in kernels]
|
||||
with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
|
||||
|
||||
@@ -19,7 +19,7 @@ def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UO
|
||||
|
||||
def Special(expr, nmax):
  # test helper: an int-typed SPECIAL UOp with no sources; its arg is the pair (expr, nmax)
  special_arg = (expr, nmax)
  return UOp(Ops.SPECIAL, dtypes.int, (), special_arg)
|
||||
def Variable(expr, nmin, nmax):
  # test helper: positional shorthand that forwards all three arguments to UOp.variable
  var = UOp.variable(expr, nmin, nmax)
  return var
|
||||
def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
|
||||
def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=n, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
def test_is_increasing(self):
|
||||
|
||||
@@ -349,7 +349,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False))
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def r(self, op:Ops, axis:Tuple[int, ...]):
  # wrap self in a REDUCE_AXIS UOp that keeps self's dtype; the arg is the pair (op, axis)
  reduce_arg = (op, axis)
  return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), reduce_arg)
|
||||
def assign(self, x:UOp):
  # build an ASSIGN UOp with self's dtype; sources are (self, x) in that order
  sources = (self, x)
  return UOp(Ops.ASSIGN, self.dtype, sources)
|
||||
def contiguous(self):
  # build a CONTIGUOUS UOp with self's dtype whose only source is self
  sources = (self,)
  return UOp(Ops.CONTIGUOUS, self.dtype, sources)
|
||||
@@ -1177,7 +1177,7 @@ syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<"
|
||||
Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"}
|
||||
renderer = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg}")),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
|
||||
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(Ops.NEG, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
|
||||
|
||||
@@ -1,12 +1,44 @@
|
||||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
from typing import Tuple, List, Optional, Dict, Set
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid
|
||||
|
||||
@functools.lru_cache(None)
def views_to_indexed_uops(views: Tuple[View, ...], _idxs:Optional[Tuple[UOp, ...]]=None) -> Tuple[UOp, UOp]:
  # Fold a stack of views into a single (idx, valid) pair of UOps.
  # Results are memoized by lru_cache, so both arguments have to be hashable (tuples, not lists).
  # the last view is indexed directly with the caller's _idxs (or its own default when None)
  idx, valid = views[-1].to_indexed_uops(_idxs)
  # then walk the remaining views from second-to-last back to the first
  for outer in reversed(views[:-1]):
    minified = outer.minify()
    # split the running flat idx into one index per dimension, innermost dimension first
    stride, per_dim = 1, []
    for dim in minified.shape[::-1]:
      per_dim.append((idx//stride)%dim)
      stride *= dim
    # indices were collected innermost-first; flip them back into shape order
    per_dim.reverse()
    idx, valid = minified.to_indexed_uops(per_dim, valid)
  return idx, valid
|
||||
|
||||
@functools.lru_cache(None)
def views_to_real_strides(views: Tuple[View, ...], ignore_valid=False) -> Tuple[Optional[sint], ...]:
  # Returns one entry per axis of the last view's shape (memoized by lru_cache).
  # NOTE: if a stride is not always valid, it will be None
  # fast path: a single view without a mask already carries its strides
  if len(views) == 1 and views[-1].mask is None: return views[-1].strides
  strides: List[Optional[sint]] = [None] * len(views[-1].shape)
  raw_idx, raw_valid = views_to_indexed_uops(views)
  idx = graph_rewrite(raw_idx, symbolic_flat)
  valid = graph_rewrite(raw_valid, symbolic_flat)
  # TODO: always apply these in to_indexed_uops?
  simplified_valid = simplify_valid(valid)
  if simplified_valid is not None: valid = simplified_valid
  narrowed_idx = uop_given_valid(valid, idx)
  if narrowed_idx is not None: idx = graph_rewrite(narrowed_idx, symbolic_flat)
  # read strides off the ADD terms of idx: a bare RANGE has stride 1, RANGE*CONST (either order) has stride CONST
  for term in split_uop(idx, Ops.ADD):
    if term.op is Ops.RANGE:
      strides[term.arg] = 1
    elif term.op is Ops.MUL:
      lhs, rhs = term.src[0], term.src[1]
      if lhs.op is Ops.RANGE and rhs.op is Ops.CONST: strides[lhs.arg] = rhs.arg
      elif rhs.op is Ops.RANGE and lhs.op is Ops.CONST: strides[rhs.arg] = lhs.arg
  # an axis whose RANGE never appears anywhere in idx gets stride 0
  ranges_in_idx = [node.arg for node in idx.toposort if node.op is Ops.RANGE]
  strides = [stride if axis in ranges_in_idx else 0 for axis, stride in enumerate(strides)]
  if not ignore_valid:
    # any axis whose RANGE appears in the valid expression is reported as None
    for node in valid.toposort:
      if node.op is Ops.RANGE: strides[node.arg] = None
  return tuple(strides)
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class ShapeTracker:
|
||||
views: Tuple[View, ...]
|
||||
@@ -41,17 +73,8 @@ class ShapeTracker:
|
||||
def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]:
  # self.shape with every dimension whose position is listed in axis replaced by 1
  reduced = []
  for position, size in enumerate(self.shape):
    reduced.append(1 if position in axis else size)
  return tuple(reduced)
|
||||
|
||||
def to_uop(self) -> UOp:
  # a VIEW UOp of void dtype with no sources, carrying self as its arg
  no_sources = ()
  return UOp(Ops.VIEW, dtypes.void, no_sources, self)
|
||||
|
||||
def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
|
||||
idx, valid = self.views[-1].to_indexed_uops(_idxs)
|
||||
for view in reversed(self.views[0:-1]):
|
||||
view = view.minify()
|
||||
acc, idxs = 1, []
|
||||
for d in reversed(view.shape):
|
||||
idxs.append((idx//acc)%d)
|
||||
acc *= d
|
||||
idx, valid = view.to_indexed_uops(idxs[::-1], valid)
|
||||
return idx, valid
|
||||
def to_indexed_uops(self, _idxs:Optional[List[UOp]|Tuple[UOp, ...]]=None) -> Tuple[UOp, UOp]:
|
||||
return views_to_indexed_uops(self.views, tuple(_idxs) if _idxs is not None else None)
|
||||
|
||||
def real_size(self) -> int:
|
||||
if 0 in self.shape: return 0
|
||||
@@ -69,29 +92,12 @@ class ShapeTracker:
|
||||
unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
|
||||
return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
|
||||
|
||||
# NOTE: if a stride is not always valid, it will be None
|
||||
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
|
||||
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
|
||||
ret: List[Optional[sint]] = [None] * len(self.shape)
|
||||
idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
|
||||
# TODO: always apply these in to_indexed_uops?
|
||||
if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid
|
||||
if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat)
|
||||
for c in split_uop(idx, Ops.ADD):
|
||||
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
|
||||
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
|
||||
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
|
||||
used_ranges = [x.arg[0] for x in idx.toposort if x.op is Ops.RANGE]
|
||||
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
|
||||
if not ignore_valid:
|
||||
for masked_axis in [x.arg[0] for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
|
||||
return tuple(ret)
|
||||
|
||||
def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
  # thin delegate: the work happens in the lru_cache'd module-level views_to_real_strides, keyed on self.views
  views = self.views
  return views_to_real_strides(views, ignore_valid)
|
||||
def unit_stride_axes(self, ignore_valid=False) -> List[int]:
  # positions of the entries in self.real_strides(ignore_valid) that compare equal to 1
  strides = self.real_strides(ignore_valid)
  return [axis for axis, stride in enumerate(strides) if stride == 1]
|
||||
|
||||
def axis_is_masked(self, axis:int) -> bool:
|
||||
_, valid = self.to_indexed_uops()
|
||||
return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
|
||||
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
|
||||
|
||||
def simplify(self) -> ShapeTracker:
|
||||
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
||||
|
||||
@@ -95,7 +95,7 @@ class View:
|
||||
for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
|
||||
def __lt__(self, o:View): return self.t < o.t
|
||||
|
||||
def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
|
||||
def to_indexed_uops(self:View, _idxs:Optional[List[UOp]|Tuple[UOp, ...]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
|
||||
idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
|
||||
iexpr = sint_to_uop(self.offset)
|
||||
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
|
||||
|
||||
Reference in New Issue
Block a user