schedule linearize small cleanups [pr] (#9994)

This commit is contained in:
qazal
2025-04-23 00:42:29 +03:00
committed by GitHub
parent f4ec57baff
commit 58180caad3

View File

@@ -52,10 +52,11 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
var_vals: dict[Variable, int] = {}
while queue:
k = queue.popleft()
# map the BUFFER UOp to a subbuffer if it's a BUFFER_VIEW
if k.arg.ast.op is Ops.BUFFER_VIEW:
buffers[k.src[0]] = (base:=k.src[1].buf_uop.buffer).view(k.size, k.arg.ast.dtype, k.arg.ast.arg[1]*base.dtype.itemsize)
schedule.append(ScheduleItem(graph_rewrite(k.arg.ast, pm_unbind, ctx=var_vals), tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
# unbind var_vals from the kernel
ast = graph_rewrite(k.arg.ast, pm_unbind, ctx=var_vals)
# create subbuffers if needed
if ast.op is Ops.BUFFER_VIEW: buffers[k.src[0]] = (base:=k.src[1].buf_uop.buffer).view(k.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
schedule.append(ScheduleItem(ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
for x in children[k]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)