From eea42d864f6d6ebe991651980a711ae32ea367d0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 8 Apr 2024 20:04:19 +0300 Subject: [PATCH] account for all outputs (#4113) --- tinygrad/engine/schedule.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index cb4b8eecba..26c3a9b71b 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -197,18 +197,23 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) # preschedule all buffers in realizes prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST} + schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs} assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None} # breadth first ordering graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list) in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int) for key, si in prescheduled.items(): - for x in si.inputs: + # realize outputs after all parents are realized + scheduled_parents = set(schedule_targets[x].outputs[0] for x in si.inputs if x in schedule_targets) + for x in scheduled_parents: graph[x].append(key) - if x in assign_targets: - graph[key].append(assign_targets[x]) - in_degree[assign_targets[x]] += 1 - if x in prescheduled: in_degree[key] += 1 + in_degree[key] += 1 + # realize outputs before a parent is assigned to + parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in si.inputs if x in assign_targets) + for assign in parents_assigns: + graph[key].append(assign) + in_degree[assign] += 1 for out in si.outputs: del out.srcs # can only schedule once queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)