minor linearizer refactor to finalize in rewrite [pr] (#9148)

This commit is contained in:
George Hotz
2025-02-18 12:42:22 +08:00
committed by GitHub
parent df3b320f46
commit 2db8b4046a

View File

@@ -112,6 +112,16 @@ def block_merge(ctx, x:UOp):
pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
def block_finalize(block:UOp):
if len(block.src) == 0: return None
_uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
_uops += block.arg.lst
assert _uops[-1].op is Ops.SINK, "block doesn't end with SINK"
return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
def block_reorder(in_block:UOp):
in_this_block = set(in_block.arg.lst)
@@ -212,14 +222,11 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
# final rewrite to merge all blocks into one
sink = graph_rewrite(sink, pm_block_merge, ctx=children)
# there should just be one block left, with a few parents with 0 srcs
assert sink.op is Ops.BLOCK
_uops = sorted(dedup(sink.src), key=lambda x: x.tuplize)
assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
_uops += sink.arg.lst
# there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
sink = graph_rewrite(sink, pm_block_finalize)
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
if not skip_check: type_verify(_uops)
if not skip_check: type_verify(sink.arg.lst)
# strip the SINK
return _uops[:-1]
# return the list. TODO: refactor to return the UOp
return list(sink.arg.lst)