From bd6a068ef75002ee8b58134e0972bc15dc6b6f6b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 4 Dec 2025 03:13:45 -0800 Subject: [PATCH] move track_rewrites to outer schedule cache (#13556) Co-authored-by: qazal --- tinygrad/engine/schedule.py | 5 +++-- tinygrad/schedule/multi.py | 3 +-- tinygrad/schedule/rangeify.py | 5 ++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index dbde96a992..362073e721 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,10 +2,10 @@ import time from typing import cast from dataclasses import dataclass, field, replace from collections import deque -from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass +from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.device import Buffer, MultiBuffer -from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten +from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize # **** ScheduleItem return type @@ -113,6 +113,7 @@ from tinygrad.engine.memory import memory_planner from tinygrad.schedule.rangeify import get_rangeify_map from tinygrad.schedule.multi import get_multi_map +@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}") def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]: # big_sink srcs are all the Tensors st = time.perf_counter() diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index a665bca837..fe769996bc 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -1,7 +1,7 @@ from typing import cast import functools, itertools, operator from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv -from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, track_rewrites, graph_rewrite_map, graph_rewrite +from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite from tinygrad.device import Device # *** allreduce implementation *** @@ -218,7 +218,6 @@ multi_pm = PatternMatcher([ src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), ])+replace_allreduce -@track_rewrites() def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: if getenv("VIZ"): graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST") ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm") diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 07a1400478..d1903aa2a4 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,9 +2,9 @@ from dataclasses import dataclass, field import itertools from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo -from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str +from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str from tinygrad.uop.symbolic import symbolic -from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY +from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify from tinygrad.codegen.opt import Opt @@ -538,7 +538,6 @@ replace_contiguous = PatternMatcher([ (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), ]) -@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}") def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]: if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph") uop_list: list[UOp] = []