mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
move track_rewrites to outer schedule cache (#13556)
Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
@@ -2,10 +2,10 @@ import time
|
||||
from typing import cast
|
||||
from dataclasses import dataclass, field, replace
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten
|
||||
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
|
||||
|
||||
# **** ScheduleItem return type
|
||||
|
||||
@@ -113,6 +113,7 @@ from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
from tinygrad.schedule.multi import get_multi_map
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
|
||||
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import cast
|
||||
import functools, itertools, operator
|
||||
from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, getenv
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, track_rewrites, graph_rewrite_map, graph_rewrite
|
||||
from tinygrad.uop.ops import Ops, UOp, sint, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
|
||||
from tinygrad.device import Device
|
||||
|
||||
# *** allreduce implementation ***
|
||||
@@ -218,7 +218,6 @@ multi_pm = PatternMatcher([
|
||||
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
])+replace_allreduce
|
||||
|
||||
@track_rewrites()
|
||||
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
if getenv("VIZ"): graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
|
||||
ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")
|
||||
|
||||
@@ -2,9 +2,9 @@ from dataclasses import dataclass, field
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
@@ -538,7 +538,6 @@ replace_contiguous = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
if getenv("VIZ"): graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
uop_list: list[UOp] = []
|
||||
|
||||
Reference in New Issue
Block a user