cache apply_movement_op (#12609)

* cache apply_movement_op

* pyling and clear cache

* fix types

* ignore

* cleanup
This commit is contained in:
Sieds Lykles
2025-10-11 08:53:10 +02:00
committed by GitHub
parent 7596c1b8f5
commit 4300ebc455
3 changed files with 26 additions and 24 deletions

View File

@@ -2,6 +2,7 @@ import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.shape.shapetracker import views_to_valid_uop
from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op
from test.test_tiny import TestTiny
def uops_allocated(): return sum([isinstance(x, UOp) for x in gc.get_objects()])
@@ -69,6 +70,7 @@ if __name__ == "__main__":
# these caches will keep uops alive
method_cache.clear()
views_to_valid_uop.cache_clear()
apply_movement_op.cache_clear()
Tensor._device_seeds.clear()
Tensor._device_rng_counters.clear()

View File

@@ -1,4 +1,4 @@
from typing import Iterator, Sequence
from typing import Iterator
import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace
@@ -41,7 +41,7 @@ class BufferizeOpts:
@dataclass
class IndexingContext:
realize_map: dict[UOp, None] = field(default_factory=dict)
range_map: dict[UOp, tuple[list[UOp], list[UOp]]] = field(default_factory=dict)
range_map: dict[UOp, tuple[tuple[UOp, ...], tuple[UOp, ...]]] = field(default_factory=dict)
# create ranges
range_idx: Iterator[int] = field(default_factory=itertools.count)
@@ -103,30 +103,31 @@ pm_apply_rangeify = PatternMatcher([
])
# this is the definition of the movement ops
def apply_movement_op(x:UOp, rngs:Sequence[UOp]) -> list[UOp]:
match x.op:
case Ops.SHRINK: rngs = [a if ss == 0 else a+ss for a,(ss,_) in zip(rngs, x.arg)]
case Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)]
case Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)]
case Ops.EXPAND: rngs = [a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)]
@functools.cache
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
match op:
case Ops.SHRINK: rngs = tuple(a if ss == 0 else a+ss for a,(ss,_) in zip(rngs, arg))
case Ops.PERMUTE: rngs = tuple(rngs[p] for p in argsort(arg))
case Ops.FLIP: rngs = tuple(((s-1)-a) if f else a for a,s,f in zip(rngs, in_shape, arg))
case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
case Ops.PAD:
# TODO: why is multiple graph_rewrites faster than one here?
rngs = [r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh-e))).where(r-s, UOp.invalid()), sym, name="pad")
for r,sh,(s,e) in zip(rngs, x.shape, x.arg)]
rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), sym, name="pad")
for r,sh,(s,e) in zip(rngs, in_shape, arg))
case Ops.RESHAPE:
acc = 1
axes_in:list[UOp] = []
for s,src in list(zip(x.shape, rngs))[::-1]:
for s,src in list(zip(arg, rngs))[::-1]:
axes_in.append(acc*src)
acc *= s
combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0))
axes_out:list[UOp] = []
for s in x.src[0].shape[::-1]:
for s in in_shape[::-1]:
axes_out.append(combined_axes % s)
combined_axes //= s
# this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
rngs = list(graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src)
case _: raise RuntimeError(f"{x.op} is not a MovementOp")
rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src
case _: raise RuntimeError(f"{op} is not a MovementOp")
return rngs
@cpu_profile(TracingKey("run_rangeify"), "TINY")
@@ -157,7 +158,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map]
if x in rctx.realize_map:
# if this is in the realize_map, we create new ranges (at the output)
out_rngs = [rctx.new_range(s) if not isinstance(s, UOp) or s.op is not Ops.RANGE else s for s in x.shape]
out_rngs = tuple(rctx.new_range(s) if not isinstance(s, UOp) or s.op is not Ops.RANGE else s for s in x.shape)
# all ranges are ended now
ending_ranges[x] = False
elif x.op in {Ops.MSTACK, Ops.MSELECT}:
@@ -181,15 +182,16 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# TODO: in RANGEIFY > 1 all_all_same isn't required
all_all_same = all(same_rngs for _,_,same_rngs in rngs_valids)
out_rngs = []
_out_rngs = []
for i,(local_rngs,valids,same_rngs) in enumerate(rngs_valids):
# we compare the ranges without their valids
if all_all_same:
# the new valid is the OR of all the children valids
minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid"))
_out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid"))
else:
out_rngs.append(rctx.new_range(x.shape[i]))
_out_rngs.append(rctx.new_range(x.shape[i]))
out_rngs = tuple(_out_rngs)
# we have to realize here if there's new ranges
if not all_all_same: rctx.realize_map[x] = None
@@ -203,18 +205,16 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# 2. newly created for REDUCE_AXIS
# 3. passed through for everything else
rngs = out_rngs # rngs is the input ranges
rngs = out_rngs # rngs is the input ranges # pylint: disable=possibly-used-before-assignment
# apply movement ops
if x.op in GroupOp.Movement: rngs = apply_movement_op(x, rngs)
if x.op in GroupOp.Movement: rngs = apply_movement_op(x.op, x.src[0].shape, x.arg, rngs)
# if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do.
if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape): ending_ranges[x] = True
# REDUCE_AXIS creates ranges for the axes it is reducing
if x.op is Ops.REDUCE_AXIS:
rngs = rngs[:]
for i,s in enumerate(x.src[0].shape):
if i in x.arg[1]: rngs[i] = rctx.new_range(s, axistype=AxisType.REDUCE)
rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
if debug:
print("***" if x in rctx.realize_map else " ", len(consumer_map[x]), f"{str(x.op):20s}",

View File

@@ -103,7 +103,7 @@ earliest_rewrites = PatternMatcher([
# movement op on INDEX as a PatternMatcher
pm_mops = PatternMatcher([
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)),
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.arg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), # type: ignore
])
# *****************