move buffer refcount increment to the toposort [pr] (#9081)

This commit is contained in:
qazal
2025-02-14 13:54:22 +02:00
committed by GitHub
parent 73af42aeab
commit 65297066c2

View File

@@ -440,7 +440,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
for tensor_uop in buf_tensors[buf_uop]:
# ASSIGN just becomes the buffer in source, otherwise we reshape the buffer
becomes_map[tensor_uop] = tensor_uop.src[0] if tensor_uop.op is Ops.ASSIGN else buf_uop.reshape(tensor_uop.shape)
buf_uop.buffer.ref(1)
# create kernels, TODO: this should use the SINK from tensor_map
graph_rewrite(sink, break_sched, ctx)
@@ -475,6 +474,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
schedule: list[ScheduleItem] = []
while queue:
schedule.append(si:=queue.popleft())
# NOTE: incrementing output buffer refcounts is required by the memory planner and JIT
for out in si.outputs: out.ref(1)
for x in graph[si]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)