mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
AxisType.PLACEHOLDER in reshape to do less graph_rewrite (#13373)
* AxisType.PLACEHOLDER in reshape to do less graph_rewrite * _apply_movement_op cache
This commit is contained in:
4
test/external/external_uop_gc.py
vendored
4
test/external/external_uop_gc.py
vendored
@@ -1,7 +1,7 @@
|
||||
import gc
|
||||
from tinygrad import Tensor, UOp, Device, nn
|
||||
from tinygrad.engine.realize import method_cache, get_program
|
||||
from tinygrad.schedule.indexing import apply_movement_op
|
||||
from tinygrad.schedule.indexing import _apply_movement_op
|
||||
from tinygrad.uop.divandmod import fold_divmod_general
|
||||
from test.test_tiny import TestTiny
|
||||
|
||||
@@ -69,7 +69,7 @@ if __name__ == "__main__":
|
||||
|
||||
# these caches will keep uops alive
|
||||
method_cache.clear()
|
||||
apply_movement_op.cache_clear()
|
||||
_apply_movement_op.cache_clear()
|
||||
fold_divmod_general.cache_clear()
|
||||
Tensor._device_seeds.clear()
|
||||
Tensor._device_rng_counters.clear()
|
||||
|
||||
@@ -117,7 +117,7 @@ pm_apply_rangeify = PatternMatcher([
|
||||
|
||||
# this is the definition of the movement ops
|
||||
@functools.cache
|
||||
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
|
||||
def _apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
|
||||
match op:
|
||||
case Ops.SHRINK: rngs = tuple(a if ss == 0 else a+ss for a,(ss,_) in zip(rngs, arg))
|
||||
case Ops.PERMUTE: rngs = tuple(rngs[p] for p in argsort(arg))
|
||||
@@ -145,6 +145,13 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
|
||||
case _: raise RuntimeError(f"{op} is not a MovementOp")
|
||||
return rngs
|
||||
|
||||
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
|
||||
# for PAD and RESHAPE, we replace the ranges with PLACEHOLDERS
|
||||
if op not in (Ops.PAD, Ops.RESHAPE): return _apply_movement_op(op, in_shape, arg, rngs)
|
||||
sink = UOp.sink(*rngs)
|
||||
real_ranges = {r:UOp.range(r.src[0], i, AxisType.PLACEHOLDER) for i,r in enumerate(sink.ranges)}
|
||||
return UOp.sink(*_apply_movement_op(op, in_shape, arg, sink.substitute(real_ranges).src)).substitute({v:k for k,v in real_ranges.items()}).src
|
||||
|
||||
@profile_matches
|
||||
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
if debug: print("**************************")
|
||||
|
||||
@@ -14,12 +14,12 @@ if TYPE_CHECKING:
|
||||
class AxisType(Enum):
|
||||
def __repr__(self): return str(self)
|
||||
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
|
||||
THREAD = auto(); OUTER = auto() # noqa: E702
|
||||
THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
|
||||
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
|
||||
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O", AxisType.PLACEHOLDER: "P"}
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
|
||||
AxisType.OUTER: "green"}
|
||||
AxisType.OUTER: "green", AxisType.PLACEHOLDER: "white"}
|
||||
|
||||
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
|
||||
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
|
||||
@@ -615,6 +615,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
|
||||
def buf_target(self) -> UOp:
|
||||
# the buffer that's being loaded from or store to
|
||||
# NOTE: this is the good one to keep
|
||||
match self.op:
|
||||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return self
|
||||
case Ops.AFTER | Ops.INDEX | Ops.STORE | Ops.LOAD: return self.src[0].buf_target()
|
||||
|
||||
Reference in New Issue
Block a user