diff --git a/test/null/test_viz.py b/test/null/test_viz.py
index 4ec9468d69..8d8c74e6d9 100644
--- a/test/null/test_viz.py
+++ b/test/null/test_viz.py
@@ -282,9 +282,10 @@ class TestVizIntegration(BaseTestViz):
     ast = Tensor.schedule(Tensor.empty(4)+Tensor.empty(4))[0].ast
     prg = get_program(ast, Device[Device.DEFAULT].renderer)
     lst = get_viz_list()
-    self.assertEqual(len(lst), 2)
-    self.assertEqual(lst[0]["name"], "Schedule 1 Kernel n1")
-    self.assertEqual(lst[1]["name"], prg.name)
+    self.assertEqual(len(lst), 3)
+    self.assertEqual(lst[0]["name"], "Process 1 Buffer n1")
+    self.assertEqual(lst[1]["name"], "Schedule 1 Kernel n1")
+    self.assertEqual(lst[2]["name"], prg.name)
 
   # schedule graph CALL nodes have a link to jump to codegen
   def test_link_sched_codegen(self):
@@ -293,8 +294,9 @@ class TestVizIntegration(BaseTestViz):
     sched = Tensor.schedule(c1, c2)
     prgs = [si.lower().prg.p.name for si in sched]
     lst = get_viz_list()
-    viz_kernel = next(i for i,s in enumerate(lst[0]["steps"]) if s["name"] == "View Kernel Graph")
-    graph = next(get_viz_details(0, viz_kernel))["graph"]
+    sched_idx = next(i for i,l in enumerate(lst) if l["name"].startswith("Schedule"))
+    viz_kernel = next(i for i,s in enumerate(lst[sched_idx]["steps"]) if s["name"] == "View Kernel Graph")
+    graph = next(get_viz_details(sched_idx, viz_kernel))["graph"]
     call_nodes = [n for n in graph.values() if n["label"].startswith("CALL")]
     for i,n in enumerate(call_nodes):
       assert n["ref"] is not None
diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py
index 8508528e12..971aab5d21 100644
--- a/tinygrad/engine/allocations.py
+++ b/tinygrad/engine/allocations.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
-from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, profile_matches
+from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
 from tinygrad.dtype import ImageDType
-from tinygrad.helpers import prod, DEBUG, argsort, VIZ
+from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize
 
 @dataclass
 class AllocCtx:
@@ -125,7 +125,7 @@ pm_replace_buf = PatternMatcher([
   (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), replace_input_buffer),
 ])
 
-@profile_matches
+@track_rewrites(lambda _,ret: f"Process {pluralize('Buffer', len(ret[1]))}")
 def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
   # uop list is a list in the original_sink graph and we can map to the tags later
   # here we build buffer map
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 7fd4e361a5..ff2f1bc7db 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -1,12 +1,11 @@
 import time, inspect
 from typing import cast
 from collections import deque
-from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink
+from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
 from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR
 from tinygrad.engine.realize import ExecItem
-from tinygrad.engine.allocations import transform_to_call
 
 # **** schedule linearizer
 
@@ -59,52 +58,8 @@ def create_schedule(sched_sink:UOp) -> UOp:
       if in_degree[x] == 0: queue.append(x)
   return UOp(Ops.LINEAR, src=tuple(linearized))
 
-from tinygrad.engine.memory import memory_planner
-from tinygrad.schedule.rangeify import get_kernel_graph
-from tinygrad.uop.ops import PatternMatcher, UPat
-
-def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
-  if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
-  return ret
-
-pm_post_sched_cache = PatternMatcher([
-  (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
-  # create new BUFFERs for LUNIQUE BUFFERs from rangeify
-  (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
-])
-
-schedule_cache: dict[bytes, UOp] = {}
-@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
-def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
-  # big_sink srcs are all the Tensors
-  st = time.perf_counter()
-  big_sink, buffer_map = transform_to_call(big_sink)
-  function = big_sink.src[0]
-
-  if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
-    if SPEC: type_verify(big_sink, tensor_spec)
-    linear = create_schedule(get_kernel_graph(function))
-    if SCACHE: schedule_cache[function.key] = linear
-  else:
-    # schedule cache hit
-    linear = sc_ret
-
-  # it's a call that we late apply
-  linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
-
-  # vars used in the schedule
-  used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
-  # get var_vals
-  var_vals: dict[str, int] = {}
-  for b in big_sink.src[1:]:
-    if b.op is Ops.BIND:
-      nm = b.src[0].expr
-      if nm not in used_vars: continue
-      val = b.src[1].arg
-      assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
-      var_vals[nm] = val
-
-  # convert LINEAR to ExecItems
+def linear_to_schedule(linear:UOp) -> list[ExecItem]:
+  """Convert a LINEAR UOp to a list of ExecItems."""
   schedule: list[ExecItem] = []
   for si in linear.src:
     ast, buf_uops = si.src[0], si.src[1:]
@@ -121,17 +76,69 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
       for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
         schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
     else:
-      schedule.append(ExecItem(ast, list(ubufs), metadata))
-  with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
+      schedule.append(ExecItem(ast, cast(list[Buffer|None], ubufs), metadata))
+  return schedule
 
-  if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
+from tinygrad.engine.memory import memory_planner
+from tinygrad.schedule.rangeify import get_kernel_graph
+from tinygrad.uop.ops import PatternMatcher, UPat
+
+def create_new_buffer(ctx:tuple[dict[UOp, UOp], tuple[UOp, ...]], b:UOp):
+  if (ret:=ctx[0].get(b, None)) is None: ctx[0][b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
+  return ret
+
+pm_post_sched_cache = PatternMatcher([
+  (UPat(Ops.PARAM, name="x"), lambda ctx,x: ctx[1][x.arg]),
+  # create new BUFFERs for LUNIQUE BUFFERs from rangeify
+  (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
+])
+
+schedule_cache: dict[bytes, UOp] = {}
+def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
+  st = time.perf_counter()
+  function = big_sink.src[0]
+  if isinstance(function.arg, KernelInfo): return None
+  if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
+    if SPEC: type_verify(big_sink, tensor_spec)
+    linear = create_schedule(get_kernel_graph(function))
+    if SCACHE: schedule_cache[function.key] = linear
+  else:
+    # schedule cache hit
+    linear = sc_ret
+  if (DEBUG >= 1 and len(linear.src) > 1) or DEBUG >= 3:
     for frm in inspect.stack():
+      if frm.filename == "": continue
      if frm.filename.startswith(str(BASEDIR / "apps")): break
      if not frm.filename.startswith(str(BASEDIR)) and not frm.filename.endswith("/contextlib.py"): break
    else: frm = None
-    print(f"scheduled {len(schedule):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
+    print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
      f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
      f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
+  return graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
 
-  return buffer_map, schedule, var_vals
+pm_schedule = PatternMatcher([
+  (UPat(Ops.CALL, src=(UPat(Ops.SINK),), allow_any_len=True, name="big_sink"), lower_schedule_to_linear),
+])
+
+@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0]))}")
+def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[list[ExecItem], dict[str, int]]:
+  # big_sink srcs are all the Tensors
+  linear = graph_rewrite(big_sink, pm_schedule, name="schedule to linear")
+
+  # vars used in the schedule
+  used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
+  # get var_vals
+  var_vals: dict[str, int] = {}
+  for b in big_sink.src[1:]:
+    if b.op is Ops.BIND:
+      nm = b.src[0].expr
+      if nm not in used_vars: continue
+      val = b.src[1].arg
+      assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
+      var_vals[nm] = val
+
+  # convert LINEAR to ExecItems
+  schedule: list[ExecItem] = linear_to_schedule(linear)
+  with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
+  return schedule, var_vals
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 394bdc61ad..de4d8c4355 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -16,6 +16,7 @@ from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_eleme
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.allocations import transform_to_call
 
 # TODO: this should be the only usage of Device
 def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
@@ -255,11 +256,11 @@ class Tensor(OpMixin):
 
     NOTE: A Tensor can only be scheduled once.
     """
-    big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
+    big_sink, becomes_map = transform_to_call(UOp.sink(*[x.uop for x in (self,)+lst]))
+    _apply_map_to_tensors(becomes_map, name="buffers")
 
     # this is where the schedule cache should go
-    becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
-    _apply_map_to_tensors(becomes_map, name="buffers")
+    schedule, var_vals = complete_create_schedule_with_vars(big_sink)
     return schedule, var_vals
 
   def schedule(self, *lst:Tensor) -> list[ExecItem]:
@@ -278,7 +279,8 @@ class Tensor(OpMixin):
       # recursively realize pending assigns that this assign's value depends on
       for u in assign_uop.toposort():
        if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
-      becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
+      big_sink, becomes_map = transform_to_call(UOp.sink(assign_uop))
+      schedule, var_vals = complete_create_schedule_with_vars(big_sink)
      _apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
      run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
      # update remaining pending assigns so they reference realized buffers instead of stale lazy graphs