diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 00747f789c..5f99e16f30 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -404,19 +404,18 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list) in_degree: DefaultDict[LBScheduleItem, int] = defaultdict(int) for lsi in prescheduled: - if lsi not in in_degree: in_degree[lsi] = 0 - # realize outputs after all parents are realized - scheduled_parents = dedup(schedule_targets[x] for x in lsi.inputs if x in schedule_targets) - for x in scheduled_parents: - graph[x].append(lsi) - in_degree[lsi] += 1 # realize outputs before a parent is assigned to parents_assigns = dedup(schedule_targets[assign_targets[x]] for x in lsi.inputs if x in assign_targets) for assign in parents_assigns: graph[lsi].append(assign) in_degree[assign] += 1 + # realize outputs after all parents are realized + scheduled_parents = dedup(xsi for x in lsi.inputs if (xsi:=schedule_targets.get(x)) is not None) + for x in scheduled_parents: + graph[x].append(lsi) + in_degree[lsi] += 1 - queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0) + queue = deque(lsi for lsi in prescheduled if in_degree[lsi] == 0) schedule: List[ScheduleItem] = [] while queue: lsi = queue.popleft()