From 1ea8fcbe1bc65b6e014c1f89c80fc36ee99d507c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 3 Apr 2024 18:52:37 +0300 Subject: [PATCH] graph schedule items (#4054) --- tinygrad/engine/schedule.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 9151e01687..68e62b5108 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -199,30 +199,29 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) # breadth first ordering graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list) in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int) - for out, si in prescheduled.items(): + for key, si in prescheduled.items(): for x in si.inputs: - graph[x].append(out) + graph[x].append(key) if x in assign_targets: - graph[out].append(assign_targets[x]) + graph[key].append(assign_targets[x]) in_degree[assign_targets[x]] += 1 - if x in prescheduled: in_degree[out] += 1 - del out.srcs # can only schedule once + if x in prescheduled: in_degree[key] += 1 + for out in si.outputs: del out.srcs # can only schedule once - queue = deque(out for out in prescheduled if in_degree[out] == 0) + queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0) schedule: List[ScheduleItem] = [] kernel_number = GlobalCounters.kernel_count while queue: - buf = queue.popleft() - seen.add(buf) - ps = prescheduled[buf] + ps = queue.popleft() + for buf in ps.outputs: seen.add(buf) if GRAPH: kernel_number += 1 for out in ps.outputs: realized_lazybuffer(out, kernel_number) schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0), ps.var_vals)) - for x in graph[buf]: + for x in graph[ps.outputs[0]]: in_degree[x] -= 1 - if in_degree[x] == 0: queue.append(x) + if in_degree[x] == 0: queue.append(prescheduled[x]) # confirm everything was scheduled correctly if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):