clean up complete_create_schedule_with_vars (#14980)
* clean up complete_create_schedule_with_vars
* transform_to_call
* update viz tests
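The gist of the cleanup: complete_create_schedule_with_vars no longer calls transform_to_call itself and no longer returns a becomes_map; callers run the buffer rewrite first and get back only the schedule and var_vals. A hedged sketch of the new call sequence, mirroring the tensor.py hunk at the end of this diff (illustrative only, not a verbatim copy):

# old (before this commit):
#   becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
#   _apply_map_to_tensors(becomes_map, name="buffers")
# new (after this commit):
from tinygrad import Tensor
from tinygrad.uop.ops import UOp
from tinygrad.engine.allocations import transform_to_call
from tinygrad.engine.schedule import complete_create_schedule_with_vars

tensors = [Tensor.empty(4) + Tensor.empty(4)]
big_sink, becomes_map = transform_to_call(UOp.sink(*[t.uop for t in tensors]))
# Tensor.schedule_with_vars applies becomes_map back to the Tensors here (internal helper)
schedule, var_vals = complete_create_schedule_with_vars(big_sink)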
@@ -282,9 +282,10 @@ class TestVizIntegration(BaseTestViz):
     ast = Tensor.schedule(Tensor.empty(4)+Tensor.empty(4))[0].ast
     prg = get_program(ast, Device[Device.DEFAULT].renderer)
     lst = get_viz_list()
-    self.assertEqual(len(lst), 2)
-    self.assertEqual(lst[0]["name"], "Schedule 1 Kernel n1")
-    self.assertEqual(lst[1]["name"], prg.name)
+    self.assertEqual(len(lst), 3)
+    self.assertEqual(lst[0]["name"], "Process 1 Buffer n1")
+    self.assertEqual(lst[1]["name"], "Schedule 1 Kernel n1")
+    self.assertEqual(lst[2]["name"], prg.name)
 
   # schedule graph CALL nodes have a link to jump to codegen
   def test_link_sched_codegen(self):
@@ -293,8 +294,9 @@ class TestVizIntegration(BaseTestViz):
     sched = Tensor.schedule(c1, c2)
     prgs = [si.lower().prg.p.name for si in sched]
     lst = get_viz_list()
-    viz_kernel = next(i for i,s in enumerate(lst[0]["steps"]) if s["name"] == "View Kernel Graph")
-    graph = next(get_viz_details(0, viz_kernel))["graph"]
+    sched_idx = next(i for i,l in enumerate(lst) if l["name"].startswith("Schedule"))
+    viz_kernel = next(i for i,s in enumerate(lst[sched_idx]["steps"]) if s["name"] == "View Kernel Graph")
+    graph = next(get_viz_details(sched_idx, viz_kernel))["graph"]
     call_nodes = [n for n in graph.values() if n["label"].startswith("CALL")]
     for i,n in enumerate(call_nodes):
       assert n["ref"] is not None
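Why the list grows to three entries: transform_to_call is now tracked as its own rewrite step (see the @track_rewrites decorator in the allocations.py hunk below), so a "Process N Buffer(s)" entry shows up ahead of the usual "Schedule N Kernel(s)" one; the trailing "n1" is presumably a per-run counter that viz appends. A rough sketch of how those labels are built, assuming the obvious behavior for tinygrad.helpers.pluralize (for illustration, not the actual helper):

def pluralize(name: str, count: int) -> str:  # assumed behavior of tinygrad.helpers.pluralize
  return f"{count} {name}" + ("" if count == 1 else "s")

print(f"Process {pluralize('Buffer', 1)}")   # -> "Process 1 Buffer", matching lst[0] above
print(f"Schedule {pluralize('Kernel', 2)}")  # -> "Schedule 2 Kernels"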
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
-from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, profile_matches
+from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
 from tinygrad.dtype import ImageDType
-from tinygrad.helpers import prod, DEBUG, argsort, VIZ
+from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize
 
 @dataclass
 class AllocCtx:
@@ -125,7 +125,7 @@ pm_replace_buf = PatternMatcher([
   (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
 ])
 
-@profile_matches
+@track_rewrites(lambda _,ret: f"Process {pluralize('Buffer', len(ret[1]))}")
 def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
   # uop list is a list in the original_sink graph and we can map to the tags later
   # here we build buffer map
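The decorator change above also pins down transform_to_call's contract: it returns a rewritten sink plus a buffer map, and the size of that map is what the "Process N Buffer(s)" label counts. Judging by pm_schedule later in this diff, the rewritten sink is a CALL whose first src is the SINK turned into a parameterized function and whose remaining srcs are the concrete buffers and BINDs. A hedged sketch of that shape (assumptions drawn from the diff, not asserted by it):

from tinygrad import Tensor
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.allocations import transform_to_call

t = Tensor.empty(4) + Tensor.empty(4)
big_sink, buffer_map = transform_to_call(UOp.sink(t.uop))
print(big_sink.op, big_sink.src[0].op)  # presumably Ops.CALL, Ops.SINK
print(len(buffer_map))                  # old-buffer -> new-buffer replacements to apply back to the Tensors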
@@ -1,12 +1,11 @@
 import time, inspect
 from typing import cast
 from collections import deque
-from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
+from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
 from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
 from tinygrad.engine.realize import ExecItem
-from tinygrad.engine.allocations import transform_to_call
 
 # **** schedule linearizer
 
@@ -59,52 +58,8 @@ def create_schedule(sched_sink:UOp) -> UOp:
     if in_degree[x] == 0: queue.append(x)
   return UOp(Ops.LINEAR, src=tuple(linearized))
 
-from tinygrad.engine.memory import memory_planner
-from tinygrad.schedule.rangeify import get_kernel_graph
-from tinygrad.uop.ops import PatternMatcher, UPat
-
-def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
-  if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
-  return ret
-
-pm_post_sched_cache = PatternMatcher([
-  (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
-  # create new BUFFERs for LUNIQUE BUFFERs from rangeify
-  (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
-])
-
-schedule_cache: dict[bytes, UOp] = {}
-@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
-def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
-  # big_sink srcs are all the Tensors
-  st = time.perf_counter()
-  big_sink, buffer_map = transform_to_call(big_sink)
-  function = big_sink.src[0]
-
-  if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
-    if SPEC: type_verify(big_sink, tensor_spec)
-    linear = create_schedule(get_kernel_graph(function))
-    if SCACHE: schedule_cache[function.key] = linear
-  else:
-    # schedule cache hit
-    linear = sc_ret
-
-  # it's a call that we late apply
-  linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
-
-  # vars used in the schedule
-  used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
-  # get var_vals
-  var_vals: dict[str, int] = {}
-  for b in big_sink.src[1:]:
-    if b.op is Ops.BIND:
-      nm = b.src[0].expr
-      if nm not in used_vars: continue
-      val = b.src[1].arg
-      assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
-      var_vals[nm] = val
-
-  # convert LINEAR to ExecItems
+def linear_to_schedule(linear:UOp) -> list[ExecItem]:
+  """Convert a LINEAR UOp to a list of ExecItems."""
   schedule: list[ExecItem] = []
   for si in linear.src:
     ast, buf_uops = si.src[0], si.src[1:]
@@ -121,17 +76,69 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
       for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
         schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
     else:
-      schedule.append(ExecItem(ast, list(ubufs), metadata))
-  with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
+      schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
+  return schedule
 
-  if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
+from tinygrad.engine.memory import memory_planner
+from tinygrad.schedule.rangeify import get_kernel_graph
+from tinygrad.uop.ops import PatternMatcher, UPat
+
+def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
+  if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
+  return ret
+
+pm_post_sched_cache = PatternMatcher([
+  (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
+  # create new BUFFERs for LUNIQUE BUFFERs from rangeify
+  (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
+])
+
+schedule_cache: dict[bytes, UOp] = {}
+def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
+  st = time.perf_counter()
+  function = big_sink.src[0]
+  if isinstance(function.arg, KernelInfo): return None
+  if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
+    if SPEC: type_verify(big_sink, tensor_spec)
+    linear = create_schedule(get_kernel_graph(function))
+    if SCACHE: schedule_cache[function.key] = linear
+  else:
+    # schedule cache hit
+    linear = sc_ret
+  if (DEBUG >= 1 and len(linear.src) > 1) or DEBUG >= 3:
     for frm in inspect.stack():
       if frm.filename == "<string>": continue
       if frm.filename.startswith(str(BASEDIR / "apps")): break
      if not frm.filename.startswith(str(BASEDIR)) and not frm.filename.endswith("/contextlib.py"): break
     else:
       frm = None
-    print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
+    print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
          f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
          f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
-  return buffer_map, schedule, var_vals
+  return graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
+
+pm_schedule = PatternMatcher([
+  (UPat(Ops.CALL, src=(UPat(Ops.SINK),), allow_any_len=True, name="big_sink"), lower_schedule_to_linear),
+])
+
+@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")
+def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], dict[str, int]]:
+  # big_sink srcs are all the Tensors
+  linear = graph_rewrite(big_sink, pm_schedule, name="schedule to linear")
+
+  # vars used in the schedule
+  used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
+  # get var_vals
+  var_vals: dict[str, int] = {}
+  for b in big_sink.src[1:]:
+    if b.op is Ops.BIND:
+      nm = b.src[0].expr
+      if nm not in used_vars: continue
+      val = b.src[1].arg
+      assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
+      var_vals[nm] = val
+
+  # convert LINEAR to ExecItems
+  schedule: list[ExecItem] = linear_to_schedule(linear)
+  with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
+  return schedule, var_vals
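One consequence of keying the cache on function.key (the CALL's SINK, with the concrete buffers presumably abstracted into PARAMs by transform_to_call) is that structurally identical graphs should share a cached LINEAR, with pm_post_sched_cache re-binding the PARAMs to each call's concrete buffers afterwards. A hedged illustration; run with SCACHE=1 and a high enough DEBUG to see the "cache hit" / "CACHE MISS" tags from the print above:

from tinygrad import Tensor

s1 = (Tensor.empty(16) + 1).schedule()  # first time: CACHE MISS, LINEAR built and stored
s2 = (Tensor.empty(16) + 1).schedule()  # same structure: expected cache hit, fresh buffers via pm_post_sched_cache
print(len(s1), len(s2))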
@@ -16,6 +16,7 @@ from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_eleme
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.allocations import transform_to_call
 
 # TODO: this should be the only usage of Device
 def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
@@ -255,11 +256,11 @@ class Tensor(OpMixin):
 
     NOTE: A Tensor can only be scheduled once.
     """
-    big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
+    big_sink, becomes_map = transform_to_call(UOp.sink(*[x.uop for x in (self,)+lst]))
+    _apply_map_to_tensors(becomes_map, name="buffers")
 
     # this is where the schedule cache should go
-    becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
-    _apply_map_to_tensors(becomes_map, name="buffers")
+    schedule, var_vals = complete_create_schedule_with_vars(big_sink)
     return schedule, var_vals
 
   def schedule(self, *lst:Tensor) -> list[ExecItem]:
@@ -278,7 +279,8 @@ class Tensor(OpMixin):
     # recursively realize pending assigns that this assign's value depends on
     for u in assign_uop.toposort():
       if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
-    becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
+    big_sink, becomes_map = transform_to_call(UOp.sink(assign_uop))
+    schedule, var_vals = complete_create_schedule_with_vars(big_sink)
     _apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
     run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
     # update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
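Nothing changes at the public surface: Tensor.schedule and Tensor.schedule_with_vars keep their return types, and run_schedule consumes the result exactly as before; only the internal plumbing above moved. A minimal usage sketch with the public API only:

from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

out = (Tensor.empty(8) * 2).contiguous()
schedule, var_vals = out.schedule_with_vars()  # list[ExecItem], dict[str, int]
run_schedule(schedule, var_vals)               # roughly what Tensor.realize does under the hood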