schedule is linear (#14975)

* schedule is linear

* cleanup

* cleanups
George Hotz
2026-02-24 11:30:41 +08:00
committed by GitHub
parent 57ade7608a
commit e2b1f2620d

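The change in one picture: create_schedule used to return a list of ExecItems plus a SINK of per-kernel buffer uops, and now returns a single Ops.LINEAR UOp whose sources are the kernel CALLs in execution order. A toy, standalone sketch of that shape (the Node class and names below are invented stand-ins, not tinygrad's UOp API):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str            # stand-in for Ops.LINEAR / Ops.CALL / Ops.BUFFER
  src: tuple = ()    # child nodes; src order carries the schedule order
  arg: object = None # payload: ast name, metadata, buffer id, ...

# before: a (list_of_exec_items, buffer_sink) pair; after: one LINEAR node
calls = tuple(Node("CALL", src=(Node("KERNEL", arg=f"ast{i}"), Node("BUFFER", arg=i))) for i in range(3))
linear = Node("LINEAR", src=calls)
assert [c.src[0].arg for c in linear.src] == ["ast0", "ast1", "ast2"]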

@@ -15,7 +15,7 @@ def _unwrap_src(s: UOp) -> UOp:
   while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0]
   return s
 
-def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
+def create_schedule(sched_sink:UOp) -> UOp:
   with cpu_profile(TracingKey("toposort sched_sink")):
     # build kernel dependency graph: edges from producer kernel to consumer kernels
     children: dict[UOp, list[UOp]] = {}
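For reference, a standalone sketch of the kernel dependency graph the comment above describes, using string keys instead of UOps (the helper name and edge format are illustrative only):

def build_kernel_graph(edges: list[tuple[str, str]]) -> tuple[dict[str, list[str]], dict[str, int]]:
  # edges go from producer kernel to consumer kernel
  children: dict[str, list[str]] = {}
  in_degree: dict[str, int] = {}
  for prod, cons in edges:
    children.setdefault(prod, []).append(cons)
    in_degree.setdefault(prod, 0)                  # kernels with no producers stay at 0
    in_degree[cons] = in_degree.get(cons, 0) + 1   # count unscheduled producers of each consumer
  return children, in_degree

children, in_degree = build_kernel_graph([("k0","k1"), ("k0","k2"), ("k1","k3"), ("k2","k3")])
assert in_degree["k0"] == 0 and in_degree["k3"] == 2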
@@ -47,20 +47,17 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
with cpu_profile(TracingKey("linearize schedule")):
queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0)
pre_schedule: list[ExecItem] = []
buf_uops_list: list[UOp] = []
linearized: list[UOp] = []
while len(queue):
rk = queue.popleft()
k = rk.src[0] if rk.op is Ops.END else rk
assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}"
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
pre_schedule.append(ExecItem(k.src[0], [], k.arg.metadata))
buf_uops_list.append(UOp.sink(*buf_uops))
linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata))
for x in children.get(rk, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
return pre_schedule, UOp.sink(*buf_uops_list)
return UOp(Ops.LINEAR, src=tuple(linearized))
from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_kernel_graph
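The linearize loop above is a Kahn-style topological sort over that graph: pop a kernel with no unscheduled producers, emit it, and decrement its consumers. A minimal standalone version (toy string nodes, not UOps):

from collections import deque

def linearize(children: dict[str, list[str]], in_degree: dict[str, int]) -> list[str]:
  queue = deque(k for k, v in in_degree.items() if v == 0)   # kernels ready to run
  order: list[str] = []
  while len(queue):
    k = queue.popleft()
    order.append(k)                                          # the real loop appends a CALL uop here
    for x in children.get(k, []):
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(x)
  return order

children = {"k0": ["k1", "k2"], "k1": ["k3"], "k2": ["k3"]}
in_degree = {"k0": 0, "k1": 1, "k2": 1, "k3": 2}
assert linearize(children, in_degree) in (["k0","k1","k2","k3"], ["k0","k2","k1","k3"])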
@@ -76,42 +73,55 @@ pm_post_sched_cache = PatternMatcher([
   (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer),
 ])
 
-schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
+schedule_cache: dict[bytes, UOp] = {}
 
 @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
 def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
   # big_sink srcs are all the Tensors
   st = time.perf_counter()
   big_sink, buffer_map = transform_to_call(big_sink)
   function = big_sink.src[0]
   if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
     if SPEC: type_verify(big_sink, tensor_spec)
-    pre_schedule, buf_uops_sink = create_schedule(get_kernel_graph(function))
-    if SCACHE: schedule_cache[function.key] = (pre_schedule, buf_uops_sink)
+    linear = create_schedule(get_kernel_graph(function))
+    if SCACHE: schedule_cache[function.key] = linear
   else:
     # schedule cache hit
-    pre_schedule, buf_uops_sink = sc_ret
+    linear = sc_ret
 
-  # it's a call that we late apply
-  buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
-  # add bufs to pre_schedule
+  # it's a call that we late apply
+  linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
+  # vars used in the schedule
+  used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src])
+  # get var_vals
+  var_vals: dict[str, int] = {}
+  for b in big_sink.src[1:]:
+    if b.op is Ops.BIND:
+      nm = b.src[0].expr
+      if nm not in used_vars: continue
+      val = b.src[1].arg
+      assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
+      var_vals[nm] = val
+
+  # convert LINEAR to ExecItems
   schedule: list[ExecItem] = []
-  for i, si in enumerate(pre_schedule):
-    buf_uops = buf_uops_sink.src[i].src
+  for si in linear.src:
+    ast, buf_uops = si.src[0], si.src[1:]
     # create subbuffers if needed
-    if si.ast.op is Ops.BUFFER_VIEW:
+    if ast.op is Ops.BUFFER_VIEW:
       base = buf_uops[1].buffer
       assert isinstance(base, Buffer), "base can't be MultiBuffer"
-      buffers[buf_uops[0]] = base.view(buf_uops[0].arg, si.ast.dtype, si.ast.arg[1]*base.dtype.itemsize)
-    ubufs = tuple(b.buffer for b in buf_uops)
+      buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
+    ubufs = [b.buffer for b in buf_uops]
+    metadata = si.arg.metadata
     if any(isinstance(x, MultiBuffer) for x in ubufs):
       assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
-      dnums = [x for x in si.ast.variables() if x.expr == '_device_num']
+      dnums = [x for x in ast.variables() if x.expr == '_device_num']
       for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
-        schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
+        schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {}))
     else:
       # ONE -> ONE
-      schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
+      schedule.append(ExecItem(ast, list(ubufs), metadata))
 
   with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
   if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
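The conversion loop above unpacks each CALL in linear.src into its ast and buffer uops and, for multi-device buffers, fans the kernel out into one ExecItem per device with _device_num pinned. A simplified standalone sketch with toy classes (ToyMultiBuffer/ToyExecItem are illustrative, not tinygrad's MultiBuffer/ExecItem):

from dataclasses import dataclass, field

@dataclass
class ToyMultiBuffer:
  bufs: list                      # one concrete buffer per device

@dataclass
class ToyExecItem:
  ast: str
  bufs: list
  fixedvars: dict = field(default_factory=dict)

def to_exec_items(linear: list[tuple[str, list]]) -> list[ToyExecItem]:
  schedule: list[ToyExecItem] = []
  for ast, bufs in linear:        # each entry stands in for one CALL uop
    if any(isinstance(b, ToyMultiBuffer) for b in bufs):
      assert all(isinstance(b, ToyMultiBuffer) for b in bufs), "kernel must all be multibuffer"
      # one ExecItem per device, pinning the device number for that copy
      for j, per_dev in enumerate(zip(*[b.bufs for b in bufs])):
        schedule.append(ToyExecItem(ast, list(per_dev), {"_device_num": j}))
    else:
      schedule.append(ToyExecItem(ast, list(bufs)))
  return schedule

items = to_exec_items([("add", [ToyMultiBuffer(["d0_a","d1_a"]), ToyMultiBuffer(["d0_b","d1_b"])])])
assert len(items) == 2 and items[1].fixedvars == {"_device_num": 1}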
@@ -124,15 +134,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
           f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
           f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
 
-  # vars used in the schedule
-  used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule])
-  # get var_vals
-  var_vals: dict[str, int] = {}
-  for i,b in enumerate(big_sink.src[1:]):
-    if b.op is Ops.BIND:
-      nm = b.src[0].expr
-      if nm not in used_vars: continue
-      val = b.src[1].arg
-      assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
-      var_vals[nm] = val
   return buffer_map, schedule, var_vals
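Finally, the var_vals collection this commit moves earlier in the function (reading variables off linear.src instead of the finished schedule) boils down to: keep only bound variables the schedule actually uses, and insist repeated bindings agree. A standalone sketch with plain (name, value) tuples standing in for Ops.BIND uops:

def collect_var_vals(binds: list[tuple[str, int]], used_vars: set[str]) -> dict[str, int]:
  var_vals: dict[str, int] = {}
  for nm, val in binds:
    if nm not in used_vars: continue               # variable never shows up in any kernel ast
    assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}"
    var_vals[nm] = val
  return var_vals

assert collect_var_vals([("i", 3), ("unused", 7), ("i", 3)], {"i"}) == {"i": 3}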