mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
refactor to apply_movement_op (#12533)
* refactor to apply_movement_op * new pm_mops is fine * make mypy happy * cleanup apply_movement_op function
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Sequence
|
||||
import functools, operator, itertools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType
|
||||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.helpers import argsort, all_same, Context
|
||||
from tinygrad.uop.ops import graph_rewrite, sint, AxisType
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
|
||||
@@ -43,7 +42,6 @@ class BufferizeOpts:
|
||||
class IndexingContext:
|
||||
realize_map: dict[UOp, None] = field(default_factory=dict)
|
||||
range_map: dict[UOp, tuple[list[UOp], list[UOp]]] = field(default_factory=dict)
|
||||
pads_gate: dict[UOp, UOp] = field(default_factory=dict)
|
||||
|
||||
# create ranges
|
||||
range_idx: Iterator[int] = field(default_factory=itertools.count)
|
||||
@@ -59,7 +57,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL):
|
||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||
elif s in ctx.realize_map:
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag)
|
||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||
new_srcs.append(new_src)
|
||||
# NOTE: do we need this?
|
||||
@@ -67,7 +65,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
|
||||
def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
  """Replace `x` with a WHERE that selects `x.src[0]` where all of x's ranges are valid, and 0 elsewhere.

  Returns None (no rewrite) when `x` has no entry in `ctx.range_map`. Otherwise the returned UOp is
  registered in `ctx.range_map` with the same entry `x` had.
  """
  if x not in ctx.range_map: return None
  # AND together the validity of every range in the first element of x's range_map entry.
  # the gate is derived from the ranges alone, so no separate per-pad lookup table is consulted
  valid: UOp = functools.reduce(operator.and_, [r.get_valid() for r in ctx.range_map[x][0]], UOp.const(dtypes.bool, True))
  ret = valid.where(x.src[0], UOp.const(x.dtype, 0))
  # the replacement inherits the ranges of the op it replaces
  ctx.range_map[ret] = ctx.range_map[x]
  return ret
|
||||
|
||||
@@ -103,6 +102,34 @@ pm_apply_rangeify = PatternMatcher([
|
||||
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None),
|
||||
])
|
||||
|
||||
# this is the definition of the movement ops
|
||||
def apply_movement_op(x:UOp, rngs:Sequence[UOp]) -> list[UOp]:
  """Translate the per-axis indices of movement op `x` into per-axis indices for `x.src[0]`.

  `rngs` is the sequence of index expressions addressing `x`; the returned list addresses its source.
  Raises RuntimeError when `x.op` is none of the ops handled here.
  """
  op = x.op

  # SHRINK: add the first element of each arg pair to the index (skipped when it resolves to 0)
  if op is Ops.SHRINK:
    return [a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(rngs, x.arg)]

  # PERMUTE: reorder the indices with the argsort of the permutation
  if op is Ops.PERMUTE:
    return [rngs[p] for p in argsort(x.arg)]

  # FLIP: mirror the index on every flagged axis
  if op is Ops.FLIP:
    return [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)]

  # EXPAND: axes whose source size differs from the output size are indexed at 0
  if op is Ops.EXPAND:
    return [a.const_like(0) if resolve(in_sh!=out_sh) else a for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)]

  # PAD: inside [s, sh-e) the index is shifted by s, outside it becomes UOp.invalid()
  if op is Ops.PAD:
    padded:list[UOp] = []
    # TODO: why is multiple graph_rewrites faster than one here?
    with Context(TRACK_MATCH_STATS=0):
      for r,sh,(s,e) in zip(rngs, x.shape, x.arg):
        if s == 0 and e == 0: padded.append(r)
        else: padded.append(graph_rewrite(((r >= s) & (r < (sh-e))).where(r-s, UOp.invalid()), sym))
    return padded

  # RESHAPE: fold the indices into one linear index using x.shape, then split it again by the source shape
  if op is Ops.RESHAPE:
    stride = 1
    terms:list[UOp] = []
    for dim,rng in reversed(list(zip(x.shape, rngs))):
      terms.append(stride*rng)
      stride *= dim
    linear = sum(terms, start=UOp.const(dtypes.index, 0))
    src_axes:list[UOp] = []
    for dim in reversed(x.src[0].shape):
      src_axes.append(linear % dim)
      linear //= dim
    # this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
    return list(UOp.sink(*src_axes[::-1]).simplify().src)

  raise RuntimeError(f"{x.op} is not a MovementOp")
|
||||
|
||||
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
rctx = IndexingContext()
|
||||
|
||||
@@ -177,39 +204,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
|
||||
rngs = out_rngs # rngs is the input ranges
|
||||
|
||||
# apply movement ops. this is the definition of them
|
||||
if x.op is Ops.SHRINK: rngs = [a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(rngs, x.arg)]
|
||||
if x.op is Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)]
|
||||
if x.op is Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)]
|
||||
if x.op is Ops.EXPAND:
|
||||
rngs = [a.const_like(0) if resolve(in_sh!=out_sh) else a for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)]
|
||||
ending_ranges[x] = True
|
||||
if x.op is Ops.PAD:
|
||||
rngs = rngs[:]
|
||||
bigwhere = UOp.const(dtypes.bool, True)
|
||||
for i,(sh,(s,e)) in enumerate(zip(x.shape, x.arg)):
|
||||
if s == 0 and e == 0: continue
|
||||
where = UOp.const(dtypes.bool, True)
|
||||
if resolve(e > 0): where = where & (rngs[i] < (sh-e))
|
||||
if resolve(s > 0): where = where & (rngs[i] >= s)
|
||||
bigwhere = bigwhere & where
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
rngs[i] = graph_rewrite(where.where(rngs[i]-s, UOp.invalid()), sym)
|
||||
# PAD is replaced with a WHERE in the big graph to inject the 0s at the right place
|
||||
rctx.pads_gate[x] = bigwhere.simplify()
|
||||
if x.op is Ops.RESHAPE:
|
||||
acc = 1
|
||||
to_sum = []
|
||||
for s,src in list(zip(x.shape, rngs))[::-1]:
|
||||
to_sum.append(acc*src)
|
||||
acc *= s
|
||||
mish = sum(to_sum, start=UOp.const(dtypes.index, 0))
|
||||
ret:list[UOp] = []
|
||||
for s in x.src[0].shape[::-1]:
|
||||
ret.append(mish % s) # NOTE: simplify will turn this to CONST
|
||||
mish //= s
|
||||
# this simplify is doing a lot of heavy lifting. this is the replacement for the view merger in RESHAPE
|
||||
rngs = list(UOp.sink(*ret[::-1]).simplify().src)
|
||||
# apply movement ops
|
||||
if x.op in GroupOp.Movement: rngs = apply_movement_op(x, rngs)
|
||||
if x.op is Ops.EXPAND: ending_ranges[x] = True
|
||||
|
||||
# REDUCE_AXIS creates ranges for the axes it is reducing
|
||||
if x.op is Ops.REDUCE_AXIS:
|
||||
|
||||
@@ -2,13 +2,13 @@ from typing import cast
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.helpers import Metadata
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
|
||||
# creation can recurse a lot
|
||||
import sys
|
||||
@@ -100,69 +100,11 @@ earliest_rewrites = PatternMatcher([
|
||||
|
||||
# *****************
|
||||
# 3a. rangeify (movement)
|
||||
# NOTE: this can be deleted after the cleanup is refactored
|
||||
|
||||
def map_reshape(idx:UOp, r:UOp):
  """Rewrite an INDEX `idx` of reshape `r` into an INDEX of `r.src[0]`.

  The indices in `idx.src[1:]` are folded into one linear index using `idx.shape`, then split back
  into one index per axis of `r.src[0].shape`. The dtype and arg of `idx` are carried over.
  """
  stride = 1
  terms = []
  # walk the axes last to first, accumulating the running product of sizes as the multiplier
  for dim,rng in reversed(list(zip(idx.shape, idx.src[1:]))):
    terms.append(stride*rng)
    stride *= dim
  linear = sum(terms, start=UOp.const(dtypes.index, 0))
  src_axes:list[UOp] = []
  for dim in reversed(r.src[0].shape):
    src_axes.append(linear % dim)  # NOTE: simplify will turn this to CONST
    linear //= dim
  # simplify every axis together under one sink, then take the simplified sources back out
  simplified = UOp.sink(*src_axes[::-1]).simplify().src
  return r.src[0].index(*simplified, dtype=idx.dtype, arg=idx.arg)
|
||||
|
||||
def map_pad(idx:UOp, r:UOp):
  """Rewrite an INDEX `idx` of pad `r` into a WHERE over an INDEX of `r.src[0]`.

  Each padded axis gets a bounds condition. Where every condition holds the source is read at the
  index shifted by `s`; everywhere else the result is the constant 0.
  """
  new_idxs = list(idx.src[1:])
  in_bounds = UOp.const(dtypes.bool, True)
  for axis,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
    # axes without padding keep their index untouched
    if s == 0 and e == 0: continue
    axis_ok = UOp.const(dtypes.bool, True)
    if resolve(e > 0): axis_ok = axis_ok & (new_idxs[axis] < (sh-e))
    if resolve(s > 0): axis_ok = axis_ok & (new_idxs[axis] >= s)
    in_bounds = in_bounds & axis_ok
    with Context(TRACK_MATCH_STATS=0):
      # out-of-bounds positions index with UOp.invalid()
      new_idxs[axis] = graph_rewrite(axis_ok.where(new_idxs[axis]-s, UOp.invalid()), sym)
  # PAD is with 0
  gate = in_bounds.simplify()
  inner = r.src[0].index(*new_idxs, dtype=idx.dtype, arg=idx.arg)
  return gate.where(inner, UOp.const(r.dtype, 0))
|
||||
|
||||
def map_expand(r:UOp, idx:UOp):
  """Rewrite an INDEX `idx` of expand `r` into an INDEX of `r.src[0]`.

  Axes whose source and output sizes resolve as equal keep their index; every other axis is indexed
  at constant 0. The new INDEX arg is the minimum over the args of the ranges considered ending
  (plus `idx.arg` when it is not None), or None when there are none.
  """
  src_idxs = []
  ended = []
  kept = []
  for a,src_dim,out_dim in zip(idx.src[1:], r.src[0].shape, r.shape):
    # every RANGE reachable from this axis' index expression
    used_ranges = [u for u in a.toposort() if u.op is Ops.RANGE]
    if not resolve(src_dim==out_dim, False):
      ended.extend(used_ranges)
      src_idxs.append(a.const_like(0))
    else:
      kept.extend(used_ranges)
      src_idxs.append(a)
  # if RANGEIFY >= 2, we are aggressive about not ending ranges
  if RANGEIFY >= 2: ending_args = [u.arg for u in ended if u not in kept]
  # if RANGEIFY=1, if it's ending at all we end it
  else: ending_args = [u.arg for u in ended]
  if idx.arg is not None: ending_args.append(idx.arg)
  end_arg = min(ending_args) if ending_args else None
  return r.src[0].index(*src_idxs, arg=end_arg)
|
||||
|
||||
# movement op on INDEX as a PatternMatcher
|
||||
# rules that rewrite INDEX(movement_op, ...) into an INDEX of the movement op's source
pm_mops = PatternMatcher([
  # SHRINK / PERMUTE / FLIP are written inline: this is like the definitions of these
  (
    UPat(Ops.SHRINK, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg),
  ),
  (
    UPat(Ops.PERMUTE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg),
  ),
  (
    UPat(Ops.FLIP, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg),
  ),
  # expand needs to end ranges
  (UPat(Ops.EXPAND, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_expand),
  # reshape does a lot of symbolic stuff
  (UPat(Ops.RESHAPE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_reshape),
  # pad adds min and max
  (UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad),
  # generic rule for any movement op, delegating the index math to apply_movement_op
  # NOTE(review): looks shadowed by the op-specific rules above if they already cover every op in GroupOp.Movement - confirm
  (
    UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*apply_movement_op(r, idx.src[1:]), dtype=idx.dtype, arg=idx.arg),
  ),
])
|
||||
|
||||
# *****************
|
||||
|
||||
Reference in New Issue
Block a user