cache apply_movement_op (#12609)
* cache apply_movement_op
* pylint and clear cache
* fix types
* ignore
* cleanup
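
The key change is wrapping apply_movement_op in functools.cache, which memoizes on the argument tuple and therefore requires every argument to be hashable. That is why the signature moves from (x:UOp, rngs) with list-based ranges to separate hashable pieces (op, in_shape, arg, and a tuple of ranges), and why list builds become tuple builds throughout the diff below. A minimal sketch of the constraint, using a toy function rather than tinygrad code:

```python
import functools

# functools.cache memoizes on the call arguments, so they must all be hashable
@functools.cache
def shift_ranges(rngs: tuple[int, ...], arg: tuple[tuple[int, int], ...]) -> tuple[int, ...]:
  # toy stand-in for the SHRINK case: offset each range by its start
  return tuple(a if ss == 0 else a + ss for a, (ss, _) in zip(rngs, arg))

print(shift_ranges((0, 1), ((2, 5), (0, 3))))  # (2, 1), computed
print(shift_ranges((0, 1), ((2, 5), (0, 3))))  # (2, 1), served from the cache
print(shift_ranges.cache_info())               # CacheInfo(hits=1, misses=1, ...)
# shift_ranges([0, 1], [(2, 5), (0, 3)])       # would raise TypeError: unhashable type: 'list'
```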

test/external/external_uop_gc.py

@@ -2,6 +2,7 @@ import gc
 from tinygrad import Tensor, UOp, Device, nn
 from tinygrad.shape.shapetracker import views_to_valid_uop
 from tinygrad.engine.realize import method_cache, get_program
+from tinygrad.schedule.indexing import apply_movement_op
 from test.test_tiny import TestTiny

 def uops_allocated(): return sum([isinstance(x, UOp) for x in gc.get_objects()])
@@ -69,6 +70,7 @@ if __name__ == "__main__":
   # these caches will keep uops alive
   method_cache.clear()
   views_to_valid_uop.cache_clear()
+  apply_movement_op.cache_clear()
   Tensor._device_seeds.clear()
   Tensor._device_rng_counters.clear()
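
external_uop_gc.py counts live UOp instances through the garbage collector, so any cache that still references UOps (in its keys or its values) shows up as a leak. The new functools.cache on apply_movement_op holds range UOps, so the test now clears it alongside the existing caches. A small illustration of the effect with a toy class standing in for UOp (not tinygrad code):

```python
import functools, gc

class Node: pass  # toy stand-in for UOp

@functools.cache
def make_node(tag: str) -> Node:
  return Node()

def nodes_alive() -> int: return sum(isinstance(x, Node) for x in gc.get_objects())

make_node("a")
gc.collect()
print(nodes_alive())      # 1: the cache still references the Node
make_node.cache_clear()
gc.collect()
print(nodes_alive())      # 0: cleared, so the Node was collectable
```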

tinygrad/schedule/indexing.py

@@ -1,4 +1,4 @@
-from typing import Iterator, Sequence
+from typing import Iterator
 import functools, operator, itertools
 from dataclasses import dataclass, field
 from tinygrad.dtype import dtypes, AddrSpace
@@ -41,7 +41,7 @@ class BufferizeOpts:
 @dataclass
 class IndexingContext:
   realize_map: dict[UOp, None] = field(default_factory=dict)
-  range_map: dict[UOp, tuple[list[UOp], list[UOp]]] = field(default_factory=dict)
+  range_map: dict[UOp, tuple[tuple[UOp, ...], tuple[UOp, ...]]] = field(default_factory=dict)
   # create ranges
   range_idx: Iterator[int] = field(default_factory=itertools.count)
@@ -103,30 +103,31 @@ pm_apply_rangeify = PatternMatcher([
 ])

 # this is the definition of the movement ops
-def apply_movement_op(x:UOp, rngs:Sequence[UOp]) -> list[UOp]:
-  match x.op:
-    case Ops.SHRINK: rngs = [a if ss == 0 else a+ss for a,(ss,_) in zip(rngs, x.arg)]
-    case Ops.PERMUTE: rngs = [rngs[p] for p in argsort(x.arg)]
-    case Ops.FLIP: rngs = [((s-1)-a) if f else a for a,s,f in zip(rngs, x.shape, x.arg)]
-    case Ops.EXPAND: rngs = [a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, x.src[0].shape, x.shape)]
+@functools.cache
+def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
+  match op:
+    case Ops.SHRINK: rngs = tuple(a if ss == 0 else a+ss for a,(ss,_) in zip(rngs, arg))
+    case Ops.PERMUTE: rngs = tuple(rngs[p] for p in argsort(arg))
+    case Ops.FLIP: rngs = tuple(((s-1)-a) if f else a for a,s,f in zip(rngs, in_shape, arg))
+    case Ops.EXPAND: rngs = tuple(a if in_sh == out_sh else a.const_like(0) for a,in_sh,out_sh in zip(rngs, in_shape, arg))
     case Ops.PAD:
       # TODO: why is multiple graph_rewrites faster than one here?
-      rngs = [r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh-e))).where(r-s, UOp.invalid()), sym, name="pad")
-              for r,sh,(s,e) in zip(rngs, x.shape, x.arg)]
+      rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))).where(r-s, UOp.invalid()), sym, name="pad")
+                   for r,sh,(s,e) in zip(rngs, in_shape, arg))
     case Ops.RESHAPE:
       acc = 1
       axes_in:list[UOp] = []
-      for s,src in list(zip(x.shape, rngs))[::-1]:
+      for s,src in list(zip(arg, rngs))[::-1]:
         axes_in.append(acc*src)
         acc *= s
       combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0))
       axes_out:list[UOp] = []
-      for s in x.src[0].shape[::-1]:
+      for s in in_shape[::-1]:
         axes_out.append(combined_axes % s)
         combined_axes //= s
       # this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code
-      rngs = list(graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src)
-    case _: raise RuntimeError(f"{x.op} is not a MovementOp")
+      rngs = graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic, name="reshape").src
+    case _: raise RuntimeError(f"{op} is not a MovementOp")
   return rngs

 @cpu_profile(TracingKey("run_rangeify"), "TINY")
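
Two details in the rewritten body are worth calling out. In the PAD case the upper bound changes from r < (sh - e) with sh taken from the padded output shape to r < (sh + s) with sh taken from the input shape; since the output extent is input + s + e, both bounds are the same value. In the RESHAPE case the output-shape ranges (now read from arg instead of x.shape) are flattened into one linear index and re-split by the input shape. A plain-integer sketch of that index arithmetic, illustrative only and with no UOps involved:

```python
# Flatten an index over the output shape, then split it by the input shape:
# the same remapping apply_movement_op builds symbolically for RESHAPE.
def reshape_index(out_idx: tuple[int, ...], out_shape: tuple[int, ...], in_shape: tuple[int, ...]) -> tuple[int, ...]:
  acc, combined = 1, 0
  for s, i in list(zip(out_shape, out_idx))[::-1]:  # row-major, last axis fastest
    combined += acc * i
    acc *= s
  in_idx = []
  for s in in_shape[::-1]:
    in_idx.append(combined % s)
    combined //= s
  return tuple(in_idx[::-1])

# a (6,) tensor reshaped to (2, 3): output element (1, 2) reads input element (5,)
print(reshape_index((1, 2), (2, 3), (6,)))  # (5,)

# PAD bound equivalence: with out = in + s + e, "out - e" and "in + s" agree
in_sh, s, e = 6, 2, 3
assert (in_sh + s + e) - e == in_sh + s
```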

@@ -157,7 +158,7 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
     consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map]
     if x in rctx.realize_map:
       # if this is in the realize_map, we create new ranges (at the output)
-      out_rngs = [rctx.new_range(s) if not isinstance(s, UOp) or s.op is not Ops.RANGE else s for s in x.shape]
+      out_rngs = tuple(rctx.new_range(s) if not isinstance(s, UOp) or s.op is not Ops.RANGE else s for s in x.shape)
       # all ranges are ended now
       ending_ranges[x] = False
     elif x.op in {Ops.MSTACK, Ops.MSELECT}:
@@ -181,15 +182,16 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:

       # TODO: in RANGEIFY > 1 all_all_same isn't required
       all_all_same = all(same_rngs for _,_,same_rngs in rngs_valids)
-      out_rngs = []
+      _out_rngs = []
       for i,(local_rngs,valids,same_rngs) in enumerate(rngs_valids):
         # we compare the ranges without their valids
         if all_all_same:
           # the new valid is the OR of all the children valids
           minimum_valid = functools.reduce(operator.or_, valids, UOp.const(dtypes.bool, False))
-          out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid"))
+          _out_rngs.append(graph_rewrite(minimum_valid.where(local_rngs[0], UOp.invalid()), symbolic, name="minimum_valid"))
         else:
-          out_rngs.append(rctx.new_range(x.shape[i]))
+          _out_rngs.append(rctx.new_range(x.shape[i]))
+      out_rngs = tuple(_out_rngs)

       # we have to realize here if there's new ranges
       if not all_all_same: rctx.realize_map[x] = None
@@ -203,18 +205,16 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
     # 2. newly created for REDUCE_AXIS
     # 3. passed through for everything else

-    rngs = out_rngs  # rngs is the input ranges
+    rngs = out_rngs  # rngs is the input ranges  # pylint: disable=possibly-used-before-assignment

     # apply movement ops
-    if x.op in GroupOp.Movement: rngs = apply_movement_op(x, rngs)
+    if x.op in GroupOp.Movement: rngs = apply_movement_op(x.op, x.src[0].shape, x.arg, rngs)
     # if the EXPAND is used to inject a range, we don't mark it as ending_ranges. otherwise we do.
     if x.op is Ops.EXPAND and all(isinstance(y, int) or y.op is not Ops.RANGE for y in x.shape): ending_ranges[x] = True

     # REDUCE_AXIS creates ranges for the axes it is reducing
     if x.op is Ops.REDUCE_AXIS:
-      rngs = rngs[:]
-      for i,s in enumerate(x.src[0].shape):
-        if i in x.arg[1]: rngs[i] = rctx.new_range(s, axistype=AxisType.REDUCE)
+      rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))

     if debug:
       print("***" if x in rctx.realize_map else " ", len(consumer_map[x]), f"{str(x.op):20s}",

@@ -103,7 +103,7 @@ earliest_rewrites = PatternMatcher([
 # movement op on INDEX as a PatternMatcher
 pm_mops = PatternMatcher([
   (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
-   lambda r,idx: r.src[0].index(*apply_movement_op(r, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)),
+   lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.arg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)),  # type: ignore
 ])

 # *****************