From e2b1f2620dd38d90fef56d3f72e2c9bbb19a0c8b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 24 Feb 2026 11:30:41 +0800 Subject: [PATCH] schedule is linear (#14975) * schedule is linear * cleanup * cleanups --- tinygrad/engine/schedule.py | 69 ++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index bffed14c1d..7fd4e361a5 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -15,7 +15,7 @@ def _unwrap_src(s: UOp) -> UOp: while len(s.src) and s.op not in {Ops.AFTER, Ops.BUFFER, Ops.PARAM, Ops.MSELECT, Ops.MSTACK, Ops.BIND}: s = s.src[0] return s -def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: +def create_schedule(sched_sink:UOp) -> UOp: with cpu_profile(TracingKey("toposort sched_sink")): # build kernel dependency graph: edges from producer kernel to consumer kernels children: dict[UOp, list[UOp]] = {} @@ -47,20 +47,17 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: with cpu_profile(TracingKey("linearize schedule")): queue: deque[UOp] = deque(k for k,v in in_degree.items() if v == 0) - pre_schedule: list[ExecItem] = [] - buf_uops_list: list[UOp] = [] + linearized: list[UOp] = [] while len(queue): rk = queue.popleft() k = rk.src[0] if rk.op is Ops.END else rk assert k.op is Ops.CALL, f"unexpected op in queue: {k.op}" buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND) - pre_schedule.append(ExecItem(k.src[0], [], k.arg.metadata)) - buf_uops_list.append(UOp.sink(*buf_uops)) + linearized.append(k.src[0].call(*buf_uops, metadata=k.arg.metadata)) for x in children.get(rk, []): in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x) - - return pre_schedule, UOp.sink(*buf_uops_list) + return UOp(Ops.LINEAR, src=tuple(linearized)) from tinygrad.engine.memory import memory_planner from tinygrad.schedule.rangeify import get_kernel_graph @@ -76,42 +73,55 @@ pm_post_sched_cache = PatternMatcher([ (UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), create_new_buffer), ]) -schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {} +schedule_cache: dict[bytes, UOp] = {} @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}") def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]: # big_sink srcs are all the Tensors st = time.perf_counter() big_sink, buffer_map = transform_to_call(big_sink) - function = big_sink.src[0] + if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None: if SPEC: type_verify(big_sink, tensor_spec) - pre_schedule, buf_uops_sink = create_schedule(get_kernel_graph(function)) - if SCACHE: schedule_cache[function.key] = (pre_schedule, buf_uops_sink) + linear = create_schedule(get_kernel_graph(function)) + if SCACHE: schedule_cache[function.key] = linear else: # schedule cache hit - pre_schedule, buf_uops_sink = sc_ret - # it's a call that we late apply - buf_uops_sink = graph_rewrite(buf_uops_sink, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers") + linear = sc_ret - # add bufs to pre_schedule + # it's a call that we late apply + linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers") + + # vars used in the schedule + used_vars = set().union(*[{v.expr for v in si.src[0].variables()} for si in linear.src]) + # get var_vals + var_vals: dict[str, int] = {} + for b in big_sink.src[1:]: + if b.op is Ops.BIND: + nm = b.src[0].expr + if nm not in used_vars: continue + val = b.src[1].arg + assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}" + var_vals[nm] = val + + # convert LINEAR to ExecItems schedule: list[ExecItem] = [] - for i, si in enumerate(pre_schedule): - buf_uops = buf_uops_sink.src[i].src + for si in linear.src: + ast, buf_uops = si.src[0], si.src[1:] # create subbuffers if needed - if si.ast.op is Ops.BUFFER_VIEW: + if ast.op is Ops.BUFFER_VIEW: base = buf_uops[1].buffer assert isinstance(base, Buffer), "base can't be MultiBuffer" - buffers[buf_uops[0]] = base.view(buf_uops[0].arg, si.ast.dtype, si.ast.arg[1]*base.dtype.itemsize) - ubufs = tuple(b.buffer for b in buf_uops) + buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize) + ubufs = [b.buffer for b in buf_uops] + metadata = si.arg.metadata if any(isinstance(x, MultiBuffer) for x in ubufs): assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" - dnums = [x for x in si.ast.variables() if x.expr == '_device_num'] + dnums = [x for x in ast.variables() if x.expr == '_device_num'] for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])): - schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}))) + schedule.append(ExecItem(ast, list(bufs), metadata, {dnums[0].expr:j} if len(dnums) else {})) else: - # ONE -> ONE - schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars)) + schedule.append(ExecItem(ast, list(ubufs), metadata)) with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule) if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: @@ -124,15 +134,4 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\ f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}")) - # vars used in the schedule - used_vars = set().union(*[{v.expr for v in si.ast.variables()} for si in schedule]) - # get var_vals - var_vals: dict[str, int] = {} - for i,b in enumerate(big_sink.src[1:]): - if b.op is Ops.BIND: - nm = b.src[0].expr - if nm not in used_vars: continue - val = b.src[1].arg - assert nm not in var_vals or var_vals[nm] == val, f"bind mismatch on {nm}, {var_vals[nm]} != {val}" - var_vals[nm] = val return buffer_map, schedule, var_vals