graph schedule items (#4054)

This commit is contained in:
qazal
2024-04-03 18:52:37 +03:00
committed by GitHub
parent 52ee5b73b2
commit 1ea8fcbe1b

View File

@@ -199,30 +199,29 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
# breadth first ordering
graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
for out, si in prescheduled.items():
for key, si in prescheduled.items():
for x in si.inputs:
graph[x].append(out)
graph[x].append(key)
if x in assign_targets:
graph[out].append(assign_targets[x])
graph[key].append(assign_targets[x])
in_degree[assign_targets[x]] += 1
if x in prescheduled: in_degree[out] += 1
del out.srcs # can only schedule once
if x in prescheduled: in_degree[key] += 1
for out in si.outputs: del out.srcs # can only schedule once
queue = deque(out for out in prescheduled if in_degree[out] == 0)
queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
schedule: List[ScheduleItem] = []
kernel_number = GlobalCounters.kernel_count
while queue:
buf = queue.popleft()
seen.add(buf)
ps = prescheduled[buf]
ps = queue.popleft()
for buf in ps.outputs: seen.add(buf)
if GRAPH:
kernel_number += 1
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0),
tuple(x.buffer for x in ps.inputs if x.size != 0), ps.var_vals))
for x in graph[buf]:
for x in graph[ps.outputs[0]]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
if in_degree[x] == 0: queue.append(prescheduled[x])
# confirm everything was scheduled correctly
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):