account for all outputs (#4113)

This commit is contained in:
qazal
2024-04-08 20:04:19 +03:00
committed by GitHub
parent dbd39ab78a
commit eea42d864f

View File

@@ -197,18 +197,23 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
# preschedule all buffers in realizes
prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST}
schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
# breadth first ordering
graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
for key, si in prescheduled.items():
for x in si.inputs:
# realize outputs after all parents are realized
scheduled_parents = set(schedule_targets[x].outputs[0] for x in si.inputs if x in schedule_targets)
for x in scheduled_parents:
graph[x].append(key)
if x in assign_targets:
graph[key].append(assign_targets[x])
in_degree[assign_targets[x]] += 1
if x in prescheduled: in_degree[key] += 1
in_degree[key] += 1
# realize outputs before a parent is assigned to
parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in si.inputs if x in assign_targets)
for assign in parents_assigns:
graph[key].append(assign)
in_degree[assign] += 1
for out in si.outputs: del out.srcs # can only schedule once
queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)