refactor to apply_movement_op (#12533)

* refactor to apply_movement_op

* new pm_mops is fine

* make mypy happy

* cleanup apply_movement_op function

Author: George Hotz
Date: 2025-10-09 10:16:09 +08:00
Committed by: GitHub
Parent: c4732a18bd
Commit: 615ec6acf0
2 changed files with 42 additions and 103 deletions


@@ -1,11 +1,10 @@
from typing import Iterator
from typing import Iterator, Sequence
import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType
from tinygrad.uop.symbolic import sym
from tinygrad.helpers import argsort, all_same, Context
from tinygrad.uop.ops import graph_rewrite, sint, AxisType
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
@@ -43,7 +42,6 @@ class BufferizeOpts:
class IndexingContext:
realize_map: dict[UOp, None] = field(default_factory=dict)
range_map: dict[UOp, tuple[list[UOp], list[UOp]]] = field(default_factory=dict)
pads_gate: dict[UOp, UOp] = field(default_factory=dict)
# create ranges
range_idx: Iterator[int] = field(default_factory=itertools.count)
@@ -59,7 +57,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.ASSIGN and s.src[1].op is Ops.KERNEL):
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
elif s in ctx.realize_map:
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(s,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+tuple(ctx.range_map[s][1]), arg=BufferizeOpts(device=s.device), tag=s.tag)
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
new_srcs.append(new_src)
# NOTE: do we need this?
@@ -67,7 +65,8 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
def convert_pad_to_where_to_keep_behavior_local(ctx:IndexingContext, x:UOp):
if x not in ctx.range_map: return None
ret = ctx.pads_gate[x].where(x.src[0], UOp.const(x.dtype, 0))
valid: UOp = functools.reduce(operator.and_, [r.get_valid() for r in ctx.range_map[x][0]], UOp.const(dtypes.bool, True))
ret = valid.where(x.src[0], UOp.const(x.dtype, 0))
ctx.range_map[ret] = ctx.range_map[x]
return ret
@@ -103,6 +102,34 @@ pm_apply_rangeify = PatternMatcher([
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None),
])
# this is the definition of the movement ops
def apply_movement_op(x:UOp, rngs:Sequence[UOp]) -> list[UOp]:
match x.op:
case Ops.SHRINK: rngs = [a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(rngs, x.arg)]
case Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)]
case Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)]
case Ops.EXPAND: rngs = [a.const_like(0) if resolve(in_sh!=out_sh) else a for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)]
case Ops.PAD:
# TODO: why are multiple graph_rewrites faster than one here?
with Context(TRACK_MATCH_STATS=0):
rngs = [r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh-e))).where(r-s, UOp.invalid()), sym)
for r,sh,(s,e) in zip(rngs, x.shape, x.arg)]
case Ops.RESHAPE:
acc = 1
axes_in:list[UOp] = []
for s,src in list(zip(x.shape, rngs))[::-1]:
axes_in.append(acc*src)
acc *= s
combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0))
axes_out:list[UOp] = []
for s in x.src[0].shape[::-1]:
axes_out.append(combined_axes % s)
combined_axes //= s
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
rngs = list(UOp.sink(*axes_out[::-1]).simplify().src)
case _: raise RuntimeError(f"{x.op} is not a MovementOp")
return rngs
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
rctx = IndexingContext()
@@ -177,39 +204,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
rngs = out_rngs # rngs is the input ranges
# apply movement ops. this is the definition of them
if x.op is Ops.SHRINK: rngs = [a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(rngs, x.arg)]
if x.op is Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)]
if x.op is Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)]
if x.op is Ops.EXPAND:
rngs = [a.const_like(0) if resolve(in_sh!=out_sh) else a for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)]
ending_ranges[x] = True
if x.op is Ops.PAD:
rngs = rngs[:]
bigwhere = UOp.const(dtypes.bool, True)
for i,(sh,(s,e)) in enumerate(zip(x.shape, x.arg)):
if s == 0 and e == 0: continue
where = UOp.const(dtypes.bool, True)
if resolve(e > 0): where = where & (rngs[i] < (sh-e))
if resolve(s > 0): where = where & (rngs[i] >= s)
bigwhere = bigwhere & where
with Context(TRACK_MATCH_STATS=0):
rngs[i] = graph_rewrite(where.where(rngs[i]-s, UOp.invalid()), sym)
# PAD is replaced with a WHERE in the big graph to inject the 0s at the right place
rctx.pads_gate[x] = bigwhere.simplify()
if x.op is Ops.RESHAPE:
acc = 1
to_sum = []
for s,src in list(zip(x.shape, rngs))[::-1]:
to_sum.append(acc*src)
acc *= s
mish = sum(to_sum, start=UOp.const(dtypes.index, 0))
ret:list[UOp] = []
for s in x.src[0].shape[::-1]:
ret.append(mish % s) # NOTE: simplify will turn this to CONST
mish //= s
# this simplify is doing a lot of heavy lifting. this is the replacement for the view merger in RESHAPE
rngs = list(UOp.sink(*ret[::-1]).simplify().src)
# apply movement ops
if x.op in GroupOp.Movement: rngs = apply_movement_op(x, rngs)
if x.op is Ops.EXPAND: ending_ranges[x] = True
# REDUCE_AXIS creates ranges for the axes it is reducing
if x.op is Ops.REDUCE_AXIS:

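For reference, a minimal standalone sketch of the index arithmetic that apply_movement_op performs for RESHAPE and PERMUTE, using plain Python integers in place of tinygrad's symbolic UOp ranges. The helper names and signatures are illustrative only, not part of the codebase.

# Standalone sketch: the index remapping apply_movement_op performs for
# RESHAPE and PERMUTE, with plain Python ints in place of symbolic UOp ranges.

def reshape_indices(out_idx, out_shape, in_shape):
    # flatten the output index row-major, then unflatten into the input shape
    flat, acc = 0, 1
    for i, s in zip(reversed(out_idx), reversed(out_shape)):
        flat += acc * i
        acc *= s
    in_idx = []
    for s in reversed(in_shape):
        in_idx.append(flat % s)
        flat //= s
    return tuple(reversed(in_idx))

def permute_indices(out_idx, perm):
    # argsort(perm): for each input axis, find the output axis it moved to
    inv = sorted(range(len(perm)), key=perm.__getitem__)
    return tuple(out_idx[p] for p in inv)

For example, reshape_indices((1, 3), (2, 12), (4, 6)) returns (2, 3): flattening gives 1*12 + 3 = 15, which unflattens to row 2, column 3 of the (4, 6) input.
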

@@ -2,13 +2,13 @@ from typing import cast
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
from tinygrad.uop.symbolic import sym, symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
from tinygrad.helpers import Metadata
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
# creation can recurse a lot
import sys
@@ -100,69 +100,11 @@ earliest_rewrites = PatternMatcher([
# *****************
# 3a. rangeify (movement)
# NOTE: this can be deleted after the cleanup is refactored
def map_reshape(idx:UOp, r:UOp):
acc = 1
to_sum = []
for s,src in list(zip(idx.shape, idx.src[1:]))[::-1]:
to_sum.append(acc*src)
acc *= s
mish = sum(to_sum, start=UOp.const(dtypes.index, 0))
ret:list[UOp] = []
for s in r.src[0].shape[::-1]:
ret.append(mish % s) # NOTE: simplify will turn this to CONST
mish //= s
tret = UOp.sink(*ret[::-1]).simplify().src
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
def map_pad(idx:UOp, r:UOp):
ret = list(idx.src[1:])
bigwhere = UOp.const(dtypes.bool, True)
for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
if s == 0 and e == 0: continue
where = UOp.const(dtypes.bool, True)
if resolve(e > 0): where = where & (ret[i] < (sh-e))
if resolve(s > 0): where = where & (ret[i] >= s)
bigwhere = bigwhere & where
with Context(TRACK_MATCH_STATS=0):
ret[i] = graph_rewrite(where.where(ret[i]-s, UOp.invalid()), sym)
# PAD is with 0
return bigwhere.simplify().where(r.src[0].index(*ret, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
def map_expand(r:UOp, idx:UOp):
new_rngs = []
ending_ranges = []
non_ending_ranges = []
for a,x,y in zip(idx.src[1:], r.src[0].shape, r.shape):
axis_to_range = [u for u in a.toposort() if u.op is Ops.RANGE]
if resolve(x==y, False):
non_ending_ranges.extend(axis_to_range)
new_rngs.append(a)
else:
ending_ranges.extend(axis_to_range)
new_rngs.append(a.const_like(0))
# if RANGEIFY >= 2, we are aggressive about not ending ranges
if RANGEIFY >= 2: ending_ranges = [x.arg for x in ending_ranges if x not in non_ending_ranges]
# if RANGEIFY=1, if it's ending at all we end it
else: ending_ranges = [x.arg for x in ending_ranges]
if idx.arg is not None: ending_ranges.append(idx.arg)
return r.src[0].index(*new_rngs, arg=min(ending_ranges) if ending_ranges else None)
# movement op on INDEX as a PatternMatcher
pm_mops = PatternMatcher([
# this is like the definitions of these
(UPat(Ops.SHRINK, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)),
(UPat(Ops.PERMUTE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)),
(UPat(Ops.FLIP, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)),
# expand needs to end ranges
(UPat(Ops.EXPAND, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_expand),
# reshape does a lot of symbolic stuff
(UPat(Ops.RESHAPE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_reshape),
# pad adds min and max
(UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad),
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)),
])
# *****************
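
As a companion to the pm_mops consolidation above, a standalone sketch of the SHRINK and PAD index mappings, again with plain Python ints standing in for symbolic ranges; None stands in for UOp.invalid(), which the WHERE gate in the real code turns into the pad value 0. The helper names are illustrative, not tinygrad API.

# Standalone sketch: the SHRINK and PAD index mappings, with plain Python ints
# in place of symbolic ranges. None stands in for UOp.invalid(); the WHERE
# gate in the real code turns those positions into the pad value 0.

def shrink_index(out_idx, shrink_arg):
    # shrink_arg is a (start, end) pair per axis; input index = output index + start
    return tuple(i + s for i, (s, _) in zip(out_idx, shrink_arg))

def pad_index(out_idx, out_shape, pad_arg):
    # pad_arg is a (before, after) pair per axis of the padded (output) shape
    in_idx = []
    for i, sh, (s, e) in zip(out_idx, out_shape, pad_arg):
        if s <= i < sh - e:
            in_idx.append(i - s)   # inside the original data
        else:
            return None            # a padded position: reads the pad value 0
    return tuple(in_idx)

For example, with out_shape (5,) and pad_arg [(2, 1)], pad_index((0,), (5,), [(2, 1)]) returns None (a padded position), while pad_index((3,), (5,), [(2, 1)]) returns (1,).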