diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index d16207d93f..28dd0f1f25 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -583,20 +583,26 @@ class UOpGraph: for u in children: if in_degree[u] == 0: push(u) + scope_end: Dict[UOp, UOp] = {} self._uops = [] while queue: p,x = heapq.heappop(queue) if DEBUG >= 7: print(p,x) + if x in scope_children: scope_end[x] = x if x.op is UOps.DEFINE_ACC: idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE]) self._uops.insert(idx, x) else: self._uops.append(x) + for u, ss in scope_children.items(): + if x in ss: + ss.remove(x) + if len(ss) == 0: scope_end[u] = x for u in children[x]: in_degree[u] -= 1 if in_degree[u] == 0: push(u) - for u in (self._uops): - if u.op in END_FOR_UOP: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(END_FOR_UOP[u.op][1], None, (u,))) + # end scopes in toposort order + for u, x in scope_end.items(): self._uops.insert(self._uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,))) # sanity checks (NOTE: these can cause things to be skipped in BEAM) try: