diff --git a/test/unit/test_block_reorder.py b/test/unit/test_block_reorder.py new file mode 100644 index 0000000000..e195318c87 --- /dev/null +++ b/test/unit/test_block_reorder.py @@ -0,0 +1,50 @@ +import unittest, random +from tinygrad.dtype import dtypes +from tinygrad.ops import print_uops, UOp, Ops +from tinygrad.codegen.linearize import block_reorder +#from tinygrad.renderer.cstyle import ClangRenderer + +def is_toposorted(lst:list[UOp]): + seen = set() + for u in lst: + if any(p not in seen for p in u.src): return False + seen.add(u) + return True + +class TestBlockReorder(unittest.TestCase): + def test_loads_randomize(self): + a = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=0) + b = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=1) + c = UOp(Ops.DEFINE_GLOBAL, dtype=dtypes.float.ptr(), arg=2) + v1 = UOp(Ops.DEFINE_VAR, dtype=dtypes.int, arg=("a",)) + v2 = UOp(Ops.DEFINE_VAR, dtype=dtypes.int, arg=("b",)) + sink = c.store(sum([ + a.index(v1).load(dtype=dtypes.float), + a.index(v1+1).load(dtype=dtypes.float), + a.index(v1+2).load(dtype=dtypes.float), + b.index(v2).load(dtype=dtypes.float), + b.index(v2+1).load(dtype=dtypes.float), + b.index(v2+2).load(dtype=dtypes.float), + ])).sink() + + golden = block_reorder(sink.toposort) + + # test random order is always same + for _ in range(50): + # shuffle and form a valid toposort + lst = golden[:] + random.shuffle(lst) + topolst = [] + for u in lst: + for p in u.toposort: + if p not in topolst: topolst.append(p) + assert is_toposorted(topolst) + + for x,y in zip(golden, this_order:=block_reorder(topolst)): + if x is not y: + print_uops(golden) + print_uops(this_order) + self.assertIs(x, y) + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index e53d7fea07..1a0590c8c9 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -11,11 +11,12 @@ from tinygrad.spec import type_verify def block_reorder(lst:list[UOp]) -> list[UOp]: in_this_block = set(lst) local_children: defaultdict[UOp, list[UOp]] = defaultdict(list) - in_degree: defaultdict[UOp, int] = defaultdict(int) + in_degree:dict[UOp, int] = {} priorities:dict[UOp, int] = {} # get local children and assign priorities for u in reversed(lst): + in_degree[u] = 0 for s in u.src: if s in in_this_block: local_children[s].append(u) @@ -26,21 +27,18 @@ def block_reorder(lst:list[UOp]) -> list[UOp]: if u.op is Ops.BARRIER: priority.append(-1500) priorities[u] = min(priority) - # placement queue - queue:list[tuple[int, tuple, UOp]] = [] - def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u)) - - # place the first ones that don't have deps - for u in lst: - if u not in in_degree: push(u) + # number the uops in "ideal" order + nkey = {u:i for i,u in enumerate(sorted(lst, key=lambda x: (priorities[x], x.tuplize)))} + # then force then to be toposorted in as close to the ideal order as possible + heapq.heapify(heap:=[(nkey[u],u) for u in lst if in_degree[u] == 0]) newlst = [] - while queue: - _,_,x = heapq.heappop(queue) - newlst.append(x) - for u in local_children[x]: - in_degree[u] -= 1 - if in_degree[u] == 0: push(u) + while heap: + _,u = heapq.heappop(heap) + newlst.append(u) + for v in local_children[u]: + in_degree[v] -= 1 + if in_degree[v] == 0: heapq.heappush(heap, (nkey[v],v)) assert len(newlst) == len(lst), f"len mismatch {len(newlst)} != {len(lst)}" return newlst