Fix scope order in graph toposort [run_process_replay] (#5330)

* fix

* test

* nothing
This commit is contained in:
kormann
2024-07-08 20:46:15 +02:00
committed by GitHub
parent 631bc974a0
commit 2349d837fb
2 changed files with 19 additions and 6 deletions

View File

@@ -221,5 +221,20 @@ class TestUOpGraph(TestUOps):
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
with self.assertRaises(AssertionError): uops.linearize()
def test_switched_range_order(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
c0 = UOp.const(dtypes.int, 0)
c2 = UOp.const(dtypes.int, 2)
cf = UOp.const(dtypes.float, 0.0)
r1 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 0, False))
r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
store = UOp(UOps.STORE, None, (glbl, alu, cf))
uops = UOpGraph([store]).uops
ranges = [x for x in uops if x.op is UOps.RANGE]
endranges = [x for x in uops if x.op is UOps.ENDRANGE]
# ranges are closed in the right order
self.assertEqual(endranges[-1].src[0], ranges[0])
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -437,16 +437,14 @@ class UOpGraph:
if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
self._uops.insert(idx, x)
else:
self._uops.append(x)
for u, ss in scope_children.items():
if x in ss:
ss.remove(x)
if len(ss) == 0: self._uops.append(UOp(end_for_uop[u.op][1], None, (u,)))
else: self._uops.append(x)
for u in children[x]:
in_degree[u] -= 1
if in_degree[u] == 0: push(u)
for u in (self._uops):
if u.op in end_for_uop: self._uops.insert(max([self._uops.index(l) for l in scope_children[u]])+1, UOp(end_for_uop[u.op][1], None, (u,)))
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
self._uops = self._uops[:-1]