mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
expand ranges -> unroll outer ranges [pr] (#14440)
This commit is contained in:
@@ -1,14 +1,17 @@
|
||||
import time
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
|
||||
# **** schedule linearizer
|
||||
|
||||
# ScheduleItem = tuple[AST, buffer UOps, metadata, fixedvars, bound_ranges]
|
||||
ScheduleItem = tuple[UOp, tuple[UOp, ...], tuple[Metadata, ...], dict[str, int], tuple[UOp, ...]]
|
||||
|
||||
# unwrap VIEW/CAST/etc to find the actual data source (kernel output, buffer, or multi-device op)
|
||||
def _unwrap_src(s: UOp) -> UOp:
|
||||
while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
|
||||
@@ -48,48 +51,48 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
for k,v in in_degree.items():
|
||||
if v == 0: queue.append(k)
|
||||
|
||||
schedule: list[tuple|UOp] = []
|
||||
schedule: list[ScheduleItem|UOp] = [] # ScheduleItem for kernels, UOp for RANGE/END
|
||||
while len(queue):
|
||||
k = rk = queue.popleft()
|
||||
if k.op is Ops.END: k = k.src[0]
|
||||
assert k.op in {Ops.RANGE, Ops.KERNEL}, f"unexpected op in queue: {k.op}"
|
||||
if k.op is Ops.RANGE: schedule.append(k)
|
||||
elif k.op is Ops.KERNEL:
|
||||
ast = k.arg.ast
|
||||
ast = (kernel:=cast(Kernel, k.arg)).ast
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
schedule.append((ast, buf_uops, k.arg.metadata, {}, bound_ranges))
|
||||
schedule.append((ast, buf_uops, kernel.metadata, {}, bound_ranges))
|
||||
if rk.op is Ops.END: schedule.append(rk)
|
||||
else:
|
||||
raise RuntimeError(f"can't schedule {k.op}")
|
||||
for x in children.get(rk, []):
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
||||
with cpu_profile(TracingKey("expand ranges")):
|
||||
pre_schedule: list[ExecItem] = []
|
||||
buf_uops_list: list[UOp] = []
|
||||
sched_ptr = 0
|
||||
in_ranges: dict[UOp, int] = {}
|
||||
range_ptrs: dict[UOp, int] = {}
|
||||
while sched_ptr < len(schedule):
|
||||
si = schedule[sched_ptr]
|
||||
if isinstance(si, UOp):
|
||||
if si.op is Ops.RANGE:
|
||||
in_ranges[si] = 0
|
||||
range_ptrs[si] = sched_ptr + 1
|
||||
elif si.op is Ops.END:
|
||||
if in_ranges[si.src[1]] < si.src[1].vmax:
|
||||
in_ranges[si.src[1]] += 1
|
||||
sched_ptr = range_ptrs[si.src[1]]
|
||||
continue
|
||||
else:
|
||||
ast, buf_uops, metadata, fixedvars, bound_ranges = si
|
||||
fixedvars = fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
|
||||
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
|
||||
buf_uops_list.append(UOp.sink(*buf_uops))
|
||||
sched_ptr += 1
|
||||
with cpu_profile(TracingKey("unroll outer ranges")):
|
||||
pre_schedule, buf_uops_list = unroll_outer_ranges(schedule)
|
||||
return pre_schedule, UOp.sink(*buf_uops_list)
|
||||
|
||||
def unroll_outer_ranges(schedule:list[ScheduleItem|UOp]) -> tuple[list[ExecItem], list[UOp]]:
|
||||
pre_schedule: list[ExecItem] = []
|
||||
buf_uops_list: list[UOp] = []
|
||||
sched_ptr, in_ranges, range_ptrs = 0, dict[UOp, int](), dict[UOp, int]()
|
||||
while sched_ptr < len(schedule):
|
||||
if isinstance(si := schedule[sched_ptr], UOp):
|
||||
if si.op is Ops.RANGE:
|
||||
in_ranges[si] = 0
|
||||
range_ptrs[si] = sched_ptr + 1
|
||||
elif si.op is Ops.END:
|
||||
if in_ranges[si.src[1]] < si.src[1].vmax:
|
||||
in_ranges[si.src[1]] += 1
|
||||
sched_ptr = range_ptrs[si.src[1]]
|
||||
continue
|
||||
else:
|
||||
ast, buf_uops, metadata, _, bound_ranges = si
|
||||
fixedvars = {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
|
||||
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
|
||||
buf_uops_list.append(UOp.sink(*buf_uops))
|
||||
sched_ptr += 1
|
||||
return pre_schedule, buf_uops_list
|
||||
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
from tinygrad.schedule.multi import get_multi_map
|
||||
|
||||
Reference in New Issue
Block a user