diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 739bdd128f..f24f78a775 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -1,11 +1,10 @@ -from typing import Iterator +from typing import Iterator, Sequence import functools, operator, itertools from dataclasses import dataclass, field from tinygrad.dtype import dtypes, AddrSpace -from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp +from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType from tinygrad.uop.symbolic import sym from tinygrad.helpers import argsort, all_same, Context -from tinygrad.uop.ops import graph_rewrite, sint, AxisType ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, @@ -43,7 +42,6 @@ class BufferizeOpts: class IndexingContext: realize_map: dict[UOp, None] = field(default_factory=dict) range_map: dict[UOp, tuple[list[UOp], list[UOp]]] = field(default_factory=dict) - pads_gate: dict[UOp, UOp] = field(default_factory=dict) # create ranges range_idx: Iterator[int] = field(default_factory=itertools.count) @@ -59,7 +57,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL): if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) elif s in ctx.realize_map: - new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag) + new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag) if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0]) new_srcs.append(new_src) # NOTE: do we need this? 
@@ -67,7 +65,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp): def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp): if x not in ctx.range_map: return None - ret = ctx.pads_gate[x].where(x.src[0], UOp.const(x.dtype, 0)) + valid: UOp = functools.reduce(operator.and_, [r.get_valid() for r in ctx.range_map[x][0]], UOp.const(dtypes.bool, True)) + ret = valid.where(x.src[0], UOp.const(x.dtype, 0)) ctx.range_map[ret] = ctx.range_map[x] return ret @@ -103,6 +102,34 @@ pm_apply_rangeify = PatternMatcher([ (UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None), ]) +# this is the definition of the movement ops +def apply_movement_op(x:UOp, rngs:Sequence[UOp]) -> list[UOp]: + match x.op: + case Ops.SHRINK: rngs = [a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(rngs, x.arg)] + case Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)] + case Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)] + case Ops.EXPAND: rngs = [a.const_like(0) if resolve(in_sh!=out_sh) else a for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)] + case Ops.PAD: + # TODO: why are multiple graph_rewrites faster than one here? + with Context(TRACK_MATCH_STATS=0): + rngs = [r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh-e))).where(r-s, UOp.invalid()), sym) + for r,sh,(s,e) in zip(rngs, x.shape, x.arg)] + case Ops.RESHAPE: + acc = 1 + axes_in:list[UOp] = [] + for s,src in list(zip(x.shape, rngs))[::-1]: + axes_in.append(acc*src) + acc *= s + combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0)) + axes_out:list[UOp] = [] + for s in x.src[0].shape[::-1]: + axes_out.append(combined_axes % s) + combined_axes //= s + # this simplify is doing a lot of heavy lifting. 
this is the replacement for the reshape view merging code + rngs = list(UOp.sink(*axes_out[::-1]).simplify().src) + case _: raise RuntimeError(f"{x.op} is not a MovementOp") + return rngs + def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: rctx = IndexingContext() @@ -177,39 +204,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: rngs = out_rngs # rngs is the input ranges - # apply movement ops. this is the definition of them - if x.op is Ops.SHRINK: rngs = [a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(rngs, x.arg)] - if x.op is Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)] - if x.op is Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)] - if x.op is Ops.EXPAND: - rngs = [a.const_like(0) if resolve(in_sh!=out_sh) else a for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)] - ending_ranges[x] = True - if x.op is Ops.PAD: - rngs = rngs[:] - bigwhere = UOp.const(dtypes.bool, True) - for i,(sh,(s,e)) in enumerate(zip(x.shape, x.arg)): - if s == 0 and e == 0: continue - where = UOp.const(dtypes.bool, True) - if resolve(e > 0): where = where & (rngs[i] < (sh-e)) - if resolve(s > 0): where = where & (rngs[i] >= s) - bigwhere = bigwhere & where - with Context(TRACK_MATCH_STATS=0): - rngs[i] = graph_rewrite(where.where(rngs[i]-s, UOp.invalid()), sym) - # PAD is replaced with a WHERE in the big graph to inject the 0s at the right place - rctx.pads_gate[x] = bigwhere.simplify() - if x.op is Ops.RESHAPE: - acc = 1 - to_sum = [] - for s,src in list(zip(x.shape, rngs))[::-1]: - to_sum.append(acc*src) - acc *= s - mish = sum(to_sum, start=UOp.const(dtypes.index, 0)) - ret:list[UOp] = [] - for s in x.src[0].shape[::-1]: - ret.append(mish % s) # NOTE: simplify will turn this to CONST - mish //= s - # this simplify is doing a lot of heavy lifting. 
this is the replacement for the view merger in RESHAPE - rngs = list(UOp.sink(*ret[::-1]).simplify().src) + # apply movement ops + if x.op in GroupOp.Movement: rngs = apply_movement_op(x, rngs) + if x.op is Ops.EXPAND: ending_ranges[x] = True # REDUCE_AXIS creates ranges for the axes it is reducing if x.op is Ops.REDUCE_AXIS: diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 04bf7d0b25..081177c81d 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,13 +2,13 @@ from typing import cast from dataclasses import dataclass, field from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo -from tinygrad.uop.symbolic import sym, symbolic_simple -from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP +from tinygrad.uop.symbolic import symbolic_simple +from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP from tinygrad.helpers import Metadata from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented from tinygrad.codegen.opt import Opt -from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext +from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op # creation can recurse a lot import sys @@ -100,69 +100,11 @@ earliest_rewrites = PatternMatcher([ # ***************** # 3a. 
rangeify (movement) -# NOTE: this can be deleted after the cleanup is refactored - -def map_reshape(idx:UOp, r:UOp): - acc = 1 - to_sum = [] - for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]: - to_sum.append(acc*src) - acc *= s - mish = sum(to_sum, start=UOp.const(dtypes.index, 0)) - ret:list[UOp] = [] - for s in r.src[0].shape[::-1]: - ret.append(mish % s) # NOTE: simplify will turn this to CONST - mish //= s - tret = UOp.sink(*ret[::-1]).simplify().src - return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg) - -def map_pad(idx:UOp, r:UOp): - ret = list(idx.src[1:]) - bigwhere = UOp.const(dtypes.bool, True) - for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)): - if s == 0 and e == 0: continue - where = UOp.const(dtypes.bool, True) - if resolve(e > 0): where = where & (ret[i] < (sh-e)) - if resolve(s > 0): where = where & (ret[i] >= s) - bigwhere = bigwhere & where - with Context(TRACK_MATCH_STATS=0): - ret[i] = graph_rewrite(where.where(ret[i]-s, UOp.invalid()), sym) - # PAD is with 0 - return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0)) - -def map_expand(r:UOp, idx:UOp): - new_rngs = [] - ending_ranges = [] - non_ending_ranges = [] - for a,x,y in zip(idx.src[1:], r.src[0].shape, r.shape): - axis_to_range = [u for u in a.toposort() if u.op is Ops.RANGE] - if resolve(x==y, False): - non_ending_ranges.extend(axis_to_range) - new_rngs.append(a) - else: - ending_ranges.extend(axis_to_range) - new_rngs.append(a.const_like(0)) - # if RANGEIFY >= 2, we are aggressive about not ending ranges - if RANGEIFY >= 2: ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges] - # if RANGEIFY=1, if it's ending at all we end it - else: ending_ranges = [x.arg for x in ending_ranges] - if idx.arg is not None: ending_ranges.append(idx.arg) - return r.src[0].index(*new_rngs, arg=min(ending_ranges) if ending_ranges else None) +# movement op on INDEX as a PatternMatcher pm_mops = PatternMatcher([ - # this 
is like the definitions of these - (UPat(Ops.SHRINK, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), - lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)), - (UPat(Ops.PERMUTE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), - lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)), - (UPat(Ops.FLIP, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), - lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)), - # expand needs to end ranges - (UPat(Ops.EXPAND, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_expand), - # reshape does a lot of symbolic stuff - (UPat(Ops.RESHAPE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_reshape), - # pad adds min and max - (UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad), + (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), + lambda r,idx: r.src[0].index(*apply_movement_op(r, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), ]) # *****************