AxisType.PLACEHOLDER in reshape to do less graph_rewrite (#13373)

* AxisType.PLACEHOLDER in reshape to do less graph_rewrite

* _apply_movement_op cache
This commit is contained in:
George Hotz
2025-11-19 19:19:58 -08:00
committed by GitHub
parent 050682ab40
commit ac7559e33d
3 changed files with 14 additions and 6 deletions

View File

@@ -1,7 +1,7 @@
import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op
from tinygrad.schedule.indexing import _apply_movement_op
from tinygrad.uop.divandmod import fold_divmod_general
from test.test_tiny import TestTiny
@@ -69,7 +69,7 @@ if __name__ == "__main__":
# these caches will keep uops alive
method_cache.clear()
apply_movement_op.cache_clear()
_apply_movement_op.cache_clear()
fold_divmod_general.cache_clear()
Tensor._device_seeds.clear()
Tensor._device_rng_counters.clear()

View File

@@ -117,7 +117,7 @@ pm_apply_rangeify = PatternMatcher([
# this is the definition of the movement ops
@functools.cache
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
def _apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
match op:
case Ops.SHRINK: rngs = tuple(a if ss == 0 else a+ss for a,(ss,_) in zip(rngs, arg))
case Ops.PERMUTE: rngs = tuple(rngs[p] for p in argsort(arg))
@@ -145,6 +145,13 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO
case _: raise RuntimeError(f"{op} is not a MovementOp")
return rngs
def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]:
# for PAD and RESHAPE, we replace the ranges with PLACEHOLDERS
if op not in (Ops.PAD, Ops.RESHAPE): return _apply_movement_op(op, in_shape, arg, rngs)
sink = UOp.sink(*rngs)
real_ranges = {r:UOp.range(r.src[0], i, AxisType.PLACEHOLDER) for i,r in enumerate(sink.ranges)}
return UOp.sink(*_apply_movement_op(op, in_shape, arg, sink.substitute(real_ranges).src)).substitute({v:k for k,v in real_ranges.items()}).src
@profile_matches
def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if debug: print("**************************")

View File

@@ -14,12 +14,12 @@ if TYPE_CHECKING:
class AxisType(Enum):
def __repr__(self): return str(self)
GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
THREAD = auto(); OUTER = auto() # noqa: E702
THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O", AxisType.PLACEHOLDER: "P"}
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta",
AxisType.OUTER: "green"}
AxisType.OUTER: "green", AxisType.PLACEHOLDER: "white"}
# NOTE: LOCAL and GROUP_REDUCE have the same priority. the order here matters
axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisType.WARP: 1, AxisType.LOCAL: 2, AxisType.UPCAST: 3,
@@ -615,6 +615,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def buf_target(self) -> UOp:
# the buffer that's being loaded from or store to
# NOTE: this is the good one to keep
match self.op:
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return self
case Ops.AFTER | Ops.INDEX | Ops.STORE | Ops.LOAD: return self.src[0].buf_target()