bfs refactors from the big graph branch [pr] (#7235)

This commit is contained in:
qazal
2024-10-23 23:24:31 +03:00
committed by GitHub
parent ea11382087
commit 65bbafe3e2

View File

@@ -404,19 +404,18 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
in_degree: DefaultDict[LBScheduleItem, int] = defaultdict(int)
for lsi in prescheduled:
if lsi not in in_degree: in_degree[lsi] = 0
# realize outputs after all parents are realized
scheduled_parents = dedup(schedule_targets[x] for x in lsi.inputs if x in schedule_targets)
for x in scheduled_parents:
graph[x].append(lsi)
in_degree[lsi] += 1
# realize outputs before a parent is assigned to
parents_assigns = dedup(schedule_targets[assign_targets[x]] for x in lsi.inputs if x in assign_targets)
for assign in parents_assigns:
graph[lsi].append(assign)
in_degree[assign] += 1
# realize outputs after all parents are realized
scheduled_parents = dedup(xsi for x in lsi.inputs if (xsi:=schedule_targets.get(x)) is not None)
for x in scheduled_parents:
graph[x].append(lsi)
in_degree[lsi] += 1
queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
queue = deque(lsi for lsi in prescheduled if in_degree[lsi] == 0)
schedule: List[ScheduleItem] = []
while queue:
lsi = queue.popleft()