diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 339bc0103d..b9c47bbad2 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,14 +1,17 @@ import time from typing import cast from collections import deque -from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map +from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.device import Buffer, MultiBuffer -from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE +from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata from tinygrad.engine.realize import ExecItem # **** schedule linearizer +# ScheduleItem = tuple[AST, buffer UOps, metadata, fixedvars, bound_ranges] +ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], dict[str, int], tuple[UOp, ...]] + # unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op) def _unwrap_src(s: UOp) -> UOp: while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0] @@ -48,48 +51,48 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: for k,v in in_degree.items(): if v == 0: queue.append(k) - schedule: list[tuple|UOp] = [] + schedule: list[ScheduleItem|UOp] = [] # ScheduleItem for kernels, UOp for RANGE/END while len(queue): k = rk = queue.popleft() if k.op is Ops.END: k = k.src[0] + assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}" if k.op is Ops.RANGE: schedule.append(k) elif k.op is Ops.KERNEL: - ast = k.arg.ast + ast = (kernel:=cast(Kernel, k.arg)).ast buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND) bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE) - schedule.append((ast, buf_uops, k.arg.metadata, {}, bound_ranges)) + schedule.append((ast, buf_uops, kernel.metadata, {}, bound_ranges)) if rk.op is Ops.END: schedule.append(rk) - else: - raise RuntimeError(f"can't schedule {k.op}") for x in children.get(rk, []): in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x) - with cpu_profile(TracingKey("expand ranges")): - pre_schedule: list[ExecItem] = [] - buf_uops_list: list[UOp] = [] - sched_ptr = 0 - in_ranges: dict[UOp, int] = {} - range_ptrs: dict[UOp, int] = {} - while sched_ptr < len(schedule): - si = schedule[sched_ptr] - if isinstance(si, UOp): - if si.op is Ops.RANGE: - in_ranges[si] = 0 - range_ptrs[si] = sched_ptr + 1 - elif si.op is Ops.END: - if in_ranges[si.src[1]] < si.src[1].vmax: - in_ranges[si.src[1]] += 1 - sched_ptr = range_ptrs[si.src[1]] - continue - else: - ast, buf_uops, metadata, fixedvars, bound_ranges = si - fixedvars = fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges} - pre_schedule.append(ExecItem(ast, [], metadata, fixedvars)) - buf_uops_list.append(UOp.sink(*buf_uops)) - sched_ptr += 1 + with cpu_profile(TracingKey("unroll outer ranges")): + pre_schedule, buf_uops_list = unroll_outer_ranges(schedule) return pre_schedule, UOp.sink(*buf_uops_list) +def unroll_outer_ranges(schedule:list[ScheduleItem|UOp]) -> tuple[list[ExecItem], list[UOp]]: + pre_schedule: list[ExecItem] = [] + buf_uops_list: list[UOp] = [] + sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]() + while sched_ptr < len(schedule): + if isinstance(si := schedule[sched_ptr], UOp): + if si.op is Ops.RANGE: + in_ranges[si] = 0 + range_ptrs[si] = sched_ptr + 1 + elif si.op is Ops.END: + if in_ranges[si.src[1]] < si.src[1].vmax: + in_ranges[si.src[1]] += 1 + sched_ptr = range_ptrs[si.src[1]] + continue + else: + ast, buf_uops, metadata, _, bound_ranges = si + fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges} + pre_schedule.append(ExecItem(ast, [], metadata, fixedvars)) + buf_uops_list.append(UOp.sink(*buf_uops)) + sched_ptr += 1 + return pre_schedule, buf_uops_list + from tinygrad.engine.memory import memory_planner from tinygrad.schedule.rangeify import get_rangeify_map from tinygrad.schedule.multi import get_multi_map