move track_rewrites to outer schedule cache (#13556)

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2025-12-04 03:13:45 -08:00
committed by GitHub
parent 3eae146139
commit bd6a068ef7
3 changed files with 6 additions and 7 deletions

View File

@@ -2,10 +2,10 @@ import time
from typing import cast
from dataclasses import dataclass, field, replace
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
# **** ScheduleItem return type
@@ -113,6 +113,7 @@ from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()

View File

@@ -1,7 +1,7 @@
from typing import cast
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, track_rewrites, graph_rewrite_map, graph_rewrite
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
from tinygrad.device import Device
# *** allreduce implementation ***
@@ -218,7 +218,6 @@ multi_pm = PatternMatcher([
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
])+replace_allreduce
@track_rewrites()
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")

View File

@@ -2,9 +2,9 @@ from dataclasses import dataclass, field
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
@@ -538,7 +538,6 @@ replace_contiguous = PatternMatcher([
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}")
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
uop_list: list[UOp] = []