clean up complete_create_schedule_with_vars (#14980)

* clean up complete_create_schedule_with_vars

* transform_to_call

* update viz tests
George Hotz, 2026-02-24 16:12:36 +08:00 (committed by GitHub)
parent 8d9545e09e
commit b643fca51e
4 changed files with 76 additions and 65 deletions
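
At a high level, this change splits the old all-in-one complete_create_schedule_with_vars into two separately tracked passes: transform_to_call (which builds the tensor-to-buffer replacement map) and a slimmer complete_create_schedule_with_vars (which now returns only the schedule and var_vals). A minimal sketch of the new call order follows; the function names and the becomes_map/schedule/var_vals shapes are taken from the diffs below, but the stub bodies are purely illustrative:

# sketch only: stand-ins with the same call shapes as the refactored functions
def transform_to_call(big_sink: list[str]) -> tuple[list[str], dict[str, str]]:
  # first pass: replace tensor uops with buffers and return the replacement map
  buffer_map = {t: f"BUFFER({t})" for t in big_sink}
  return [buffer_map[t] for t in big_sink], buffer_map

def complete_create_schedule_with_vars(big_sink: list[str]) -> tuple[list[str], dict[str, int]]:
  # second pass: now returns only (schedule, var_vals); no becomes_map anymore
  return [f"ExecItem({b})" for b in big_sink], {}

big_sink, becomes_map = transform_to_call(["t0", "t1"])   # caller applies becomes_map itself
schedule, var_vals = complete_create_schedule_with_vars(big_sink)
print(becomes_map, schedule, var_vals)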

View File

@@ -282,9 +282,10 @@ class TestVizIntegration(BaseTestViz):
     ast = Tensor.schedule(Tensor.empty(4)+Tensor.empty(4))[0].ast
     prg = get_program(ast, Device[Device.DEFAULT].renderer)
     lst = get_viz_list()
-    self.assertEqual(len(lst), 2)
-    self.assertEqual(lst[0]["name"], "Schedule 1 Kernel n1")
-    self.assertEqual(lst[1]["name"], prg.name)
+    self.assertEqual(len(lst), 3)
+    self.assertEqual(lst[0]["name"], "Process 1 Buffer n1")
+    self.assertEqual(lst[1]["name"], "Schedule 1 Kernel n1")
+    self.assertEqual(lst[2]["name"], prg.name)

   # schedule graph CALL nodes have a link to jump to codegen
   def test_link_sched_codegen(self):
@@ -293,8 +294,9 @@ class TestVizIntegration(BaseTestViz):
     sched = Tensor.schedule(c1, c2)
     prgs = [si.lower().prg.p.name for si in sched]
     lst = get_viz_list()
-    viz_kernel = next(i for i,s in enumerate(lst[0]["steps"]) if s["name"] == "View Kernel Graph")
-    graph = next(get_viz_details(0, viz_kernel))["graph"]
+    sched_idx = next(i for i,l in enumerate(lst) if l["name"].startswith("Schedule"))
+    viz_kernel = next(i for i,s in enumerate(lst[sched_idx]["steps"]) if s["name"] == "View Kernel Graph")
+    graph = next(get_viz_details(sched_idx, viz_kernel))["graph"]
     call_nodes = [n for n in graph.values() if n["label"].startswith("CALL")]
     for i,n in enumerate(call_nodes):
       assert n["ref"] is not None
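
Because the viz list now carries a "Process N Buffer(s)" entry ahead of the schedule entry, the updated test locates the schedule row by name prefix instead of assuming index 0. A tiny self-contained illustration of that lookup (the entry names here are invented to mimic the test's expectations; "E_4" is just a placeholder kernel name):

lst = [{"name": "Process 1 Buffer n1"}, {"name": "Schedule 1 Kernel n1"}, {"name": "E_4"}]
sched_idx = next(i for i, l in enumerate(lst) if l["name"].startswith("Schedule"))
assert sched_idx == 1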

View File

@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
-from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, profile_matches
+from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
 from tinygrad.dtype import ImageDType
-from tinygrad.helpers import prod, DEBUG, argsort, VIZ
+from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize
 
 @dataclass
 class AllocCtx:
@@ -125,7 +125,7 @@ pm_replace_buf = PatternMatcher([
   (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
 ])
 
-@profile_matches
+@track_rewrites(lambda _,ret: f"Process {pluralize('Buffer', len(ret[1]))}")
 def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
   # uop list is a list in the original_sink graph and we can map to the tags later
   # here we build buffer map
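
The @track_rewrites decorator is what makes transform_to_call show up as its own "Process N Buffer(s)" entry in viz: its label lambda receives the function's return value, so ret[1] is the buffer map. A rough sketch of how that label is formed, with pluralize written out as a stand-in (the real helper lives in tinygrad.helpers; the "1 Buffer" / "2 Buffers" behavior is assumed from the viz names in the test above):

def pluralize(st: str, cnt: int) -> str:
  # stand-in: "1 Buffer", "2 Buffers" -- assumed behavior of tinygrad.helpers.pluralize
  return f"{cnt} {st}{'s' if cnt != 1 else ''}"

def label(_, ret: tuple[object, dict]) -> str:
  # ret mirrors transform_to_call's return value: (big_sink, buffer_map)
  return f"Process {pluralize('Buffer', len(ret[1]))}"

print(label(None, (None, {"t0": "b0"})))              # Process 1 Buffer
print(label(None, (None, {"t0": "b0", "t1": "b1"})))  # Process 2 Buffers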

View File

@@ -1,12 +1,11 @@
 import time, inspect
 from typing import cast
 from collections import deque
-from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
+from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
 from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
 from tinygrad.engine.realize import ExecItem
-from tinygrad.engine.allocations import transform_to_call
 
 # **** schedule linearizer
@@ -59,52 +58,8 @@ def create_schedule(sched_sink:UOp) -> UOp:
       if in_degree[x] == 0: queue.append(x)
   return UOp(Ops.LINEAR, src=tuple(linearized))
 
-from tinygrad.engine.memory import memory_planner
-from tinygrad.schedule.rangeify import get_kernel_graph
-from tinygrad.uop.ops import PatternMatcher, UPat
-
-def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
-  if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
-  return ret
-
-pm_post_sched_cache = PatternMatcher([
-  (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
-  # create new BUFFERs for LUNIQUE BUFFERs from rangeify
-  (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
-])
-
-schedule_cache: dict[bytes, UOp] = {}
-
-@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
-def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
-  # big_sink srcs are all the Tensors
-  st = time.perf_counter()
-  big_sink, buffer_map = transform_to_call(big_sink)
-  function = big_sink.src[0]
-  if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
-    if SPEC: type_verify(big_sink, tensor_spec)
-    linear = create_schedule(get_kernel_graph(function))
-    if SCACHE: schedule_cache[function.key] = linear
-  else:
-    # schedule cache hit
-    linear = sc_ret
-
-  # it's a call that we late apply
-  linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
-
-  # vars used in the schedule
-  used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
-
-  # get var_vals
-  var_vals: dict[str, int] = {}
-  for b in big_sink.src[1:]:
-    if b.op is Ops.BIND:
-      nm = b.src[0].expr
-      if nm not in used_vars: continue
-      val = b.src[1].arg
-      assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
-      var_vals[nm] = val
-  # convert LINEAR to ExecItems
+def linear_to_schedule(linear:UOp) -> list[ExecItem]:
+  """Convert a LINEAR UOp to a list of ExecItems."""
   schedule: list[ExecItem] = []
   for si in linear.src:
     ast, buf_uops = si.src[0], si.src[1:]
@@ -121,17 +76,69 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
       for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
         schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
     else:
-      schedule.append(ExecItem(ast, list(ubufs), metadata))
-  with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
+      schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
+  return schedule
 
-  if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
+from tinygrad.engine.memory import memory_planner
+from tinygrad.schedule.rangeify import get_kernel_graph
+from tinygrad.uop.ops import PatternMatcher, UPat
+
+def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
+  if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
+  return ret
+
+pm_post_sched_cache = PatternMatcher([
+  (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
+  # create new BUFFERs for LUNIQUE BUFFERs from rangeify
+  (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
+])
+
+schedule_cache: dict[bytes, UOp] = {}
+
+def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
+  st = time.perf_counter()
+  function = big_sink.src[0]
+  if isinstance(function.arg, KernelInfo): return None
+  if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
+    if SPEC: type_verify(big_sink, tensor_spec)
+    linear = create_schedule(get_kernel_graph(function))
+    if SCACHE: schedule_cache[function.key] = linear
+  else:
+    # schedule cache hit
+    linear = sc_ret
+  if (DEBUG >= 1 and len(linear.src) > 1) or DEBUG >= 3:
     for frm in inspect.stack():
       if frm.filename == "<string>": continue
       if frm.filename.startswith(str(BASEDIR / "apps")): break
       if not frm.filename.startswith(str(BASEDIR)) and not frm.filename.endswith("/contextlib.py"): break
     else:
       frm = None
-    print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
+    print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
       f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
       f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
-  return buffer_map, schedule, var_vals
+  return graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
+
+pm_schedule = PatternMatcher([
+  (UPat(Ops.CALL, src=(UPat(Ops.SINK),), allow_any_len=True, name="big_sink"), lower_schedule_to_linear),
+])
+
+@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")
+def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], dict[str, int]]:
+  # big_sink srcs are all the Tensors
+  linear = graph_rewrite(big_sink, pm_schedule, name="schedule to linear")
+  # vars used in the schedule
+  used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
+  # get var_vals
+  var_vals: dict[str, int] = {}
+  for b in big_sink.src[1:]:
+    if b.op is Ops.BIND:
+      nm = b.src[0].expr
+      if nm not in used_vars: continue
+      val = b.src[1].arg
+      assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
+      var_vals[nm] = val
+
+  # convert LINEAR to ExecItems
+  schedule: list[ExecItem] = linear_to_schedule(linear)
+  with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
+  return schedule, var_vals
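
The interesting part of the new lower_schedule_to_linear is the ordering: the expensive lowering to LINEAR is cached per function.key (SCACHE), while the cheap "params to buffers" rewrite runs on every call, so a cache hit still gets buffers bound for the current arguments. A simplified, self-contained stand-in for that shape (the real code works on UOp graphs; strings and a plain dict are used here only to show the caching pattern):

schedule_cache: dict[bytes, list[str]] = {}

def lower(key: bytes, params: tuple[str, ...], scache: bool = True) -> list[str]:
  if not scache or (linear := schedule_cache.get(key)) is None:
    linear = [f"kernel(PARAM {i})" for i in range(len(params))]  # stands in for the expensive create_schedule(get_kernel_graph(...))
    if scache: schedule_cache[key] = linear
  # late apply: rewrite PARAM placeholders to this call's actual buffers (pm_post_sched_cache's job)
  return [k.replace(f"PARAM {i}", params[i]) for i, k in enumerate(linear)]

print(lower(b"f0", ("bufA",)))  # CACHE MISS: ['kernel(bufA)']
print(lower(b"f0", ("bufB",)))  # cache hit, still bound to bufB: ['kernel(bufB)']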

View File

@@ -16,6 +16,7 @@ from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_eleme
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.allocations import transform_to_call
 
 # TODO: this should be the only usage of Device
 def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
@@ -255,11 +256,11 @@ class Tensor(OpMixin):
     NOTE: A Tensor can only be scheduled once.
     """
-    big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
+    big_sink, becomes_map = transform_to_call(UOp.sink(*[x.uop for x in (self,)+lst]))
+    _apply_map_to_tensors(becomes_map, name="buffers")
 
     # this is where the schedule cache should go
-    becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
-    _apply_map_to_tensors(becomes_map, name="buffers")
+    schedule, var_vals = complete_create_schedule_with_vars(big_sink)
     return schedule, var_vals
 
   def schedule(self, *lst:Tensor) -> list[ExecItem]:
@@ -278,7 +279,8 @@ class Tensor(OpMixin):
     # recursively realize pending assigns that this assign's value depends on
     for u in assign_uop.toposort():
       if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
-    becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
+    big_sink, becomes_map = transform_to_call(UOp.sink(assign_uop))
+    schedule, var_vals = complete_create_schedule_with_vars(big_sink)
     _apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
     run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
     # update remaining pending assigns so they reference realized buffers instead of stale lazy graphs