diff --git a/test/external/external_uop_gc.py b/test/external/external_uop_gc.py
index f54c24e7c5..6399b2df19 100644
--- a/test/external/external_uop_gc.py
+++ b/test/external/external_uop_gc.py
@@ -1,7 +1,7 @@
 import gc
 from tinygrad import Tensor, UOp, Device, nn
 from tinygrad.engine.realize import method_cache, get_program
-from tinygrad.schedule.indexing import apply_movement_op
+from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
 from tinygrad.uop.divandmod import fold_divmod_general
 from test.test_tiny import TestTiny
 
@@ -70,6 +70,7 @@ if __name__ == "__main__":
   # these caches will keep uops alive
   method_cache.clear()
   apply_movement_op.cache_clear()
+  _apply_reshape.cache_clear()
   fold_divmod_general.cache_clear()
   Tensor._device_seeds.clear()
   Tensor._device_rng_counters.clear()
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 770eb280ca..58117ee1b4 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -3,7 +3,7 @@ import time, pprint, random, itertools, math
 from dataclasses import dataclass, replace, field
 from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
 from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU, getenv, cpu_profile, PROFILE, ProfilePointEvent, cpu_events, prod, Context
-from tinygrad.helpers import unwrap, disable_gc
+from tinygrad.helpers import unwrap
 from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
@@ -13,7 +13,6 @@ from tinygrad.codegen.opt import Opt
 
 # **************** Program Creation ****************
 
-@disable_gc()
 @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
 def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
   """
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 5b3d891123..dbde96a992 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -2,10 +2,10 @@ import time
 from typing import cast
 from dataclasses import dataclass, field, replace
 from collections import deque
-from tinygrad.uop.ops import UOp, Ops, buffers
+from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
-from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, disable_gc
+from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten
 
 # **** ScheduleItem return type
 
@@ -113,7 +113,6 @@ from tinygrad.engine.memory import memory_planner
 from tinygrad.schedule.rangeify import get_rangeify_map
 from tinygrad.schedule.multi import get_multi_map
 
-@disable_gc()
 def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
   # big_sink srcs are all the Tensors
   st = time.perf_counter()
@@ -139,5 +138,6 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
 
   # remove all AFTERs, after scheduling, the tensors are just buffers
   tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
-  if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
+  if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
+    print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)")
   return tensor_map, schedule, var_vals
diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py
index 7d1a90a118..5f1b140fb8 100644
--- a/tinygrad/schedule/indexing.py
+++ b/tinygrad/schedule/indexing.py
@@ -119,6 +119,21 @@ pm_apply_rangeify = PatternMatcher([
   (UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None),
 ])
 
+@functools.cache
+def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:UOp) -> UOp:
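+  # editor's note: fold the out_shape ranges into a single flat (row-major) index, then split that index back apart along in_shape with mod/div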
ms") + if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: + print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)") return tensor_map, schedule, var_vals diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 7d1a90a118..5f1b140fb8 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -119,6 +119,21 @@ pm_apply_rangeify = PatternMatcher([ (UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda ctx,c: c.replace(src=()) if c in ctx.range_map else None), ]) +@functools.cache +def _apply_reshape(in_shape:tuple[sint,...], out_shape:tuple[sint, ...], urngs:UOp) -> UOp: + acc = 1 + axes_in:list[UOp] = [] + for s,src in list(zip(out_shape, urngs.src))[::-1]: + axes_in.append(acc*src) + acc *= s + combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0)) + axes_out:list[UOp] = [] + for s in in_shape[::-1]: + axes_out.append(combined_axes % s) + combined_axes //= s + # this simplify is doing a lot of heavy lifting. this is the replacement for the reshape view merging code + return graph_rewrite(UOp.sink(*axes_out[::-1]), symbolic+pm_simplify_valid+pm_drop_and_clauses, name="reshape") + # this is the definition of the movement ops @functools.cache def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UOp, ...]) -> tuple[UOp, ...]: @@ -134,18 +149,9 @@ def apply_movement_op(op:Ops, in_shape:tuple[sint,...], arg:tuple, rngs:tuple[UO rngs = tuple(r if (s == 0 and e == 0) else graph_rewrite(((r >= s) & (r < (sh+s))), symbolic+pm_simplify_valid, name="pad").where(r-s, UOp.invalid()) for r,sh,(s,e) in zip(rngs, in_shape, arg)) case Ops.RESHAPE: - acc = 1 - axes_in:list[UOp] = [] - for s,src in list(zip(arg, rngs))[::-1]: - axes_in.append(acc*src) - acc *= s - combined_axes = sum(axes_in, start=UOp.const(dtypes.index, 0)) - axes_out:list[UOp] = [] - for s in in_shape[::-1]: - axes_out.append(combined_axes % s) - combined_axes //= s - # this simplify is doing a lot of heavy lifting. 
+      sink = UOp.sink(*rngs)
+      sub_array = {r:UOp.range(r.src[0], i, AxisType.PLACEHOLDER) for i,r in enumerate(sink.ranges)}
+      rngs = _apply_reshape(in_shape, arg, sink.substitute(sub_array)).substitute({v:k for k,v in sub_array.items()}).src
     case _: raise RuntimeError(f"{op} is not a MovementOp")
   return rngs
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 8659727ec3..6c8fec27dd 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -7,7 +7,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
-from tinygrad.helpers import suppress_finalizing
+from tinygrad.helpers import suppress_finalizing, disable_gc
 from tinygrad.gradient import compute_gradient
 from tinygrad.mixin import OpMixin
 from tinygrad.mixin.movement import _align_left
@@ -241,6 +241,7 @@ class Tensor(OpMixin):
     assert len(var_vals) == 0
     return schedule
 
+  @disable_gc()
   def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
     """Triggers the computation needed to create these Tensor(s)."""
     if len(to_realize:=[x for x in (self,)+lst if not x.uop.is_contiguous()]):
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 836a355a6d..7041bb7e3b 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
 class AxisType(Enum):
   def __repr__(self): return str(self)
   GLOBAL = auto(); WARP = auto(); LOCAL = auto(); LOOP = auto(); GROUP_REDUCE = auto(); REDUCE = auto(); UPCAST = auto(); UNROLL = auto() # noqa: E702
-  THREAD = auto(); OUTER = auto() # noqa: E702
+  THREAD = auto(); OUTER = auto(); PLACEHOLDER = auto() # noqa: E702
 axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l", AxisType.WARP: "w", AxisType.LOOP: "L", AxisType.UPCAST: "u",
                 AxisType.GROUP_REDUCE: "G", AxisType.REDUCE: "R", AxisType.UNROLL: "r", AxisType.OUTER: "O"}
 axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",